services: restructure services into folders
This commit is contained in:
20
CHANGELOG.md
20
CHANGELOG.md
@@ -47,6 +47,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- It is now possible to tell whether `UserStartedSpeakingFrame` or
|
||||
`UserStoppedSpeakingFrame` have been generated because of emulation frames.
|
||||
|
||||
### Changed
|
||||
|
||||
- Pipecat services have been reorganized into packages. Each package can have
|
||||
one or more of the following modules (in the future new module names might be
|
||||
needed) depending on the services implemented:
|
||||
- image: for image generation services
|
||||
- llm: for LLM services
|
||||
- memory: for memory services
|
||||
- stt: for Speech-To-Text services
|
||||
- tts: for Text-To-Speech services
|
||||
- video: for video generation services
|
||||
- vision: for video recognition services
|
||||
|
||||
### Deprecated
|
||||
|
||||
- All Pipecat services imports have been deprecated and a warning will be shown
|
||||
when using the old import. The new import should be
|
||||
`pipecat.services.[service].[image,llm,memory,stt,tts,video,vision]`. For
|
||||
example, `from pipecat.services.openai.llm import OpenAILLMService`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue that would cause `SegmentedSTTService` based services
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def _warn_deprecated_access(globals: Dict[str, Any], attr, old: str, new: str):
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
f"Module `pipecat.services.{old}` is deprecated, use `pipecat.services.{new}` instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
return globals[attr]
|
||||
|
||||
|
||||
class DeprecatedModuleProxy:
|
||||
def __init__(self, globals: Dict[str, Any], old: str, new: str):
|
||||
self._globals = globals
|
||||
self._old = old
|
||||
self._new = new
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self._globals:
|
||||
return _warn_deprecated_access(self._globals, attr, self._old, self._new)
|
||||
raise AttributeError(f"module 'pipecat.{self._old}' has no attribute '{attr}'")
|
||||
|
||||
13
src/pipecat/services/anthropic/__init__.py
Normal file
13
src/pipecat/services/anthropic/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "anthropic", "anthropic.llm")
|
||||
13
src/pipecat/services/assemblyai/__init__.py
Normal file
13
src/pipecat/services/assemblyai/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .stt import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "assemblyai", "assemblyai.stt")
|
||||
@@ -27,9 +27,7 @@ try:
|
||||
from assemblyai import AudioEncoding
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`. Also, set `ASSEMBLYAI_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
13
src/pipecat/services/aws/__init__.py
Normal file
13
src/pipecat/services/aws/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts")
|
||||
@@ -1,813 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from openai import AsyncAzureOpenAI
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
URLImageRawFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import ImageGenService, STTService, TTSService
|
||||
from pipecat.services.openai import (
|
||||
OpenAILLMService,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
# See .env.example for Azure configuration needed
|
||||
try:
|
||||
from azure.cognitiveservices.speech import (
|
||||
CancellationReason,
|
||||
ResultReason,
|
||||
ServicePropertyChannel,
|
||||
SpeechConfig,
|
||||
SpeechRecognizer,
|
||||
SpeechSynthesisOutputFormat,
|
||||
SpeechSynthesizer,
|
||||
)
|
||||
from azure.cognitiveservices.speech.audio import (
|
||||
AudioStreamFormat,
|
||||
PushAudioInputStream,
|
||||
)
|
||||
from azure.cognitiveservices.speech.dialog import AudioConfig
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_azure_language(language: Language) -> Optional[str]:
|
||||
language_map = {
|
||||
# Afrikaans
|
||||
Language.AF: "af-ZA",
|
||||
Language.AF_ZA: "af-ZA",
|
||||
# Amharic
|
||||
Language.AM: "am-ET",
|
||||
Language.AM_ET: "am-ET",
|
||||
# Arabic
|
||||
Language.AR: "ar-AE", # Default to UAE Arabic
|
||||
Language.AR_AE: "ar-AE",
|
||||
Language.AR_BH: "ar-BH",
|
||||
Language.AR_DZ: "ar-DZ",
|
||||
Language.AR_EG: "ar-EG",
|
||||
Language.AR_IQ: "ar-IQ",
|
||||
Language.AR_JO: "ar-JO",
|
||||
Language.AR_KW: "ar-KW",
|
||||
Language.AR_LB: "ar-LB",
|
||||
Language.AR_LY: "ar-LY",
|
||||
Language.AR_MA: "ar-MA",
|
||||
Language.AR_OM: "ar-OM",
|
||||
Language.AR_QA: "ar-QA",
|
||||
Language.AR_SA: "ar-SA",
|
||||
Language.AR_SY: "ar-SY",
|
||||
Language.AR_TN: "ar-TN",
|
||||
Language.AR_YE: "ar-YE",
|
||||
# Assamese
|
||||
Language.AS: "as-IN",
|
||||
Language.AS_IN: "as-IN",
|
||||
# Azerbaijani
|
||||
Language.AZ: "az-AZ",
|
||||
Language.AZ_AZ: "az-AZ",
|
||||
# Bulgarian
|
||||
Language.BG: "bg-BG",
|
||||
Language.BG_BG: "bg-BG",
|
||||
# Bengali
|
||||
Language.BN: "bn-IN", # Default to Indian Bengali
|
||||
Language.BN_BD: "bn-BD",
|
||||
Language.BN_IN: "bn-IN",
|
||||
# Bosnian
|
||||
Language.BS: "bs-BA",
|
||||
Language.BS_BA: "bs-BA",
|
||||
# Catalan
|
||||
Language.CA: "ca-ES",
|
||||
Language.CA_ES: "ca-ES",
|
||||
# Czech
|
||||
Language.CS: "cs-CZ",
|
||||
Language.CS_CZ: "cs-CZ",
|
||||
# Welsh
|
||||
Language.CY: "cy-GB",
|
||||
Language.CY_GB: "cy-GB",
|
||||
# Danish
|
||||
Language.DA: "da-DK",
|
||||
Language.DA_DK: "da-DK",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_AT: "de-AT",
|
||||
Language.DE_CH: "de-CH",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Greek
|
||||
Language.EL: "el-GR",
|
||||
Language.EL_GR: "el-GR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US English
|
||||
Language.EN_AU: "en-AU",
|
||||
Language.EN_CA: "en-CA",
|
||||
Language.EN_GB: "en-GB",
|
||||
Language.EN_HK: "en-HK",
|
||||
Language.EN_IE: "en-IE",
|
||||
Language.EN_IN: "en-IN",
|
||||
Language.EN_KE: "en-KE",
|
||||
Language.EN_NG: "en-NG",
|
||||
Language.EN_NZ: "en-NZ",
|
||||
Language.EN_PH: "en-PH",
|
||||
Language.EN_SG: "en-SG",
|
||||
Language.EN_TZ: "en-TZ",
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_ZA: "en-ZA",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain Spanish
|
||||
Language.ES_AR: "es-AR",
|
||||
Language.ES_BO: "es-BO",
|
||||
Language.ES_CL: "es-CL",
|
||||
Language.ES_CO: "es-CO",
|
||||
Language.ES_CR: "es-CR",
|
||||
Language.ES_CU: "es-CU",
|
||||
Language.ES_DO: "es-DO",
|
||||
Language.ES_EC: "es-EC",
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_GQ: "es-GQ",
|
||||
Language.ES_GT: "es-GT",
|
||||
Language.ES_HN: "es-HN",
|
||||
Language.ES_MX: "es-MX",
|
||||
Language.ES_NI: "es-NI",
|
||||
Language.ES_PA: "es-PA",
|
||||
Language.ES_PE: "es-PE",
|
||||
Language.ES_PR: "es-PR",
|
||||
Language.ES_PY: "es-PY",
|
||||
Language.ES_SV: "es-SV",
|
||||
Language.ES_US: "es-US",
|
||||
Language.ES_UY: "es-UY",
|
||||
Language.ES_VE: "es-VE",
|
||||
# Estonian
|
||||
Language.ET: "et-EE",
|
||||
Language.ET_EE: "et-EE",
|
||||
# Basque
|
||||
Language.EU: "eu-ES",
|
||||
Language.EU_ES: "eu-ES",
|
||||
# Persian
|
||||
Language.FA: "fa-IR",
|
||||
Language.FA_IR: "fa-IR",
|
||||
# Finnish
|
||||
Language.FI: "fi-FI",
|
||||
Language.FI_FI: "fi-FI",
|
||||
# Filipino
|
||||
Language.FIL: "fil-PH",
|
||||
Language.FIL_PH: "fil-PH",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_BE: "fr-BE",
|
||||
Language.FR_CA: "fr-CA",
|
||||
Language.FR_CH: "fr-CH",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# Irish
|
||||
Language.GA: "ga-IE",
|
||||
Language.GA_IE: "ga-IE",
|
||||
# Galician
|
||||
Language.GL: "gl-ES",
|
||||
Language.GL_ES: "gl-ES",
|
||||
# Gujarati
|
||||
Language.GU: "gu-IN",
|
||||
Language.GU_IN: "gu-IN",
|
||||
# Hebrew
|
||||
Language.HE: "he-IL",
|
||||
Language.HE_IL: "he-IL",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Croatian
|
||||
Language.HR: "hr-HR",
|
||||
Language.HR_HR: "hr-HR",
|
||||
# Hungarian
|
||||
Language.HU: "hu-HU",
|
||||
Language.HU_HU: "hu-HU",
|
||||
# Armenian
|
||||
Language.HY: "hy-AM",
|
||||
Language.HY_AM: "hy-AM",
|
||||
# Indonesian
|
||||
Language.ID: "id-ID",
|
||||
Language.ID_ID: "id-ID",
|
||||
# Icelandic
|
||||
Language.IS: "is-IS",
|
||||
Language.IS_IS: "is-IS",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Inuktitut
|
||||
Language.IU_CANS_CA: "iu-Cans-CA",
|
||||
Language.IU_LATN_CA: "iu-Latn-CA",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Javanese
|
||||
Language.JV: "jv-ID",
|
||||
Language.JV_ID: "jv-ID",
|
||||
# Georgian
|
||||
Language.KA: "ka-GE",
|
||||
Language.KA_GE: "ka-GE",
|
||||
# Kazakh
|
||||
Language.KK: "kk-KZ",
|
||||
Language.KK_KZ: "kk-KZ",
|
||||
# Khmer
|
||||
Language.KM: "km-KH",
|
||||
Language.KM_KH: "km-KH",
|
||||
# Kannada
|
||||
Language.KN: "kn-IN",
|
||||
Language.KN_IN: "kn-IN",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Lao
|
||||
Language.LO: "lo-LA",
|
||||
Language.LO_LA: "lo-LA",
|
||||
# Lithuanian
|
||||
Language.LT: "lt-LT",
|
||||
Language.LT_LT: "lt-LT",
|
||||
# Latvian
|
||||
Language.LV: "lv-LV",
|
||||
Language.LV_LV: "lv-LV",
|
||||
# Macedonian
|
||||
Language.MK: "mk-MK",
|
||||
Language.MK_MK: "mk-MK",
|
||||
# Malayalam
|
||||
Language.ML: "ml-IN",
|
||||
Language.ML_IN: "ml-IN",
|
||||
# Mongolian
|
||||
Language.MN: "mn-MN",
|
||||
Language.MN_MN: "mn-MN",
|
||||
# Marathi
|
||||
Language.MR: "mr-IN",
|
||||
Language.MR_IN: "mr-IN",
|
||||
# Malay
|
||||
Language.MS: "ms-MY",
|
||||
Language.MS_MY: "ms-MY",
|
||||
# Maltese
|
||||
Language.MT: "mt-MT",
|
||||
Language.MT_MT: "mt-MT",
|
||||
# Burmese
|
||||
Language.MY: "my-MM",
|
||||
Language.MY_MM: "my-MM",
|
||||
# Norwegian
|
||||
Language.NB: "nb-NO",
|
||||
Language.NB_NO: "nb-NO",
|
||||
Language.NO: "nb-NO",
|
||||
# Nepali
|
||||
Language.NE: "ne-NP",
|
||||
Language.NE_NP: "ne-NP",
|
||||
# Dutch
|
||||
Language.NL: "nl-NL",
|
||||
Language.NL_BE: "nl-BE",
|
||||
Language.NL_NL: "nl-NL",
|
||||
# Odia
|
||||
Language.OR: "or-IN",
|
||||
Language.OR_IN: "or-IN",
|
||||
# Punjabi
|
||||
Language.PA: "pa-IN",
|
||||
Language.PA_IN: "pa-IN",
|
||||
# Polish
|
||||
Language.PL: "pl-PL",
|
||||
Language.PL_PL: "pl-PL",
|
||||
# Pashto
|
||||
Language.PS: "ps-AF",
|
||||
Language.PS_AF: "ps-AF",
|
||||
# Portuguese
|
||||
Language.PT: "pt-PT",
|
||||
Language.PT_BR: "pt-BR",
|
||||
Language.PT_PT: "pt-PT",
|
||||
# Romanian
|
||||
Language.RO: "ro-RO",
|
||||
Language.RO_RO: "ro-RO",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Sinhala
|
||||
Language.SI: "si-LK",
|
||||
Language.SI_LK: "si-LK",
|
||||
# Slovak
|
||||
Language.SK: "sk-SK",
|
||||
Language.SK_SK: "sk-SK",
|
||||
# Slovenian
|
||||
Language.SL: "sl-SI",
|
||||
Language.SL_SI: "sl-SI",
|
||||
# Somali
|
||||
Language.SO: "so-SO",
|
||||
Language.SO_SO: "so-SO",
|
||||
# Albanian
|
||||
Language.SQ: "sq-AL",
|
||||
Language.SQ_AL: "sq-AL",
|
||||
# Serbian
|
||||
Language.SR: "sr-RS",
|
||||
Language.SR_RS: "sr-RS",
|
||||
Language.SR_LATN: "sr-Latn-RS",
|
||||
Language.SR_LATN_RS: "sr-Latn-RS",
|
||||
# Sundanese
|
||||
Language.SU: "su-ID",
|
||||
Language.SU_ID: "su-ID",
|
||||
# Swedish
|
||||
Language.SV: "sv-SE",
|
||||
Language.SV_SE: "sv-SE",
|
||||
# Swahili
|
||||
Language.SW: "sw-KE",
|
||||
Language.SW_KE: "sw-KE",
|
||||
Language.SW_TZ: "sw-TZ",
|
||||
# Tamil
|
||||
Language.TA: "ta-IN",
|
||||
Language.TA_IN: "ta-IN",
|
||||
Language.TA_LK: "ta-LK",
|
||||
Language.TA_MY: "ta-MY",
|
||||
Language.TA_SG: "ta-SG",
|
||||
# Telugu
|
||||
Language.TE: "te-IN",
|
||||
Language.TE_IN: "te-IN",
|
||||
# Thai
|
||||
Language.TH: "th-TH",
|
||||
Language.TH_TH: "th-TH",
|
||||
# Turkish
|
||||
Language.TR: "tr-TR",
|
||||
Language.TR_TR: "tr-TR",
|
||||
# Ukrainian
|
||||
Language.UK: "uk-UA",
|
||||
Language.UK_UA: "uk-UA",
|
||||
# Urdu
|
||||
Language.UR: "ur-IN",
|
||||
Language.UR_IN: "ur-IN",
|
||||
Language.UR_PK: "ur-PK",
|
||||
# Uzbek
|
||||
Language.UZ: "uz-UZ",
|
||||
Language.UZ_UZ: "uz-UZ",
|
||||
# Vietnamese
|
||||
Language.VI: "vi-VN",
|
||||
Language.VI_VN: "vi-VN",
|
||||
# Wu Chinese
|
||||
Language.WUU: "wuu-CN",
|
||||
Language.WUU_CN: "wuu-CN",
|
||||
# Yue Chinese
|
||||
Language.YUE: "yue-CN",
|
||||
Language.YUE_CN: "yue-CN",
|
||||
# Chinese
|
||||
Language.ZH: "zh-CN",
|
||||
Language.ZH_CN: "zh-CN",
|
||||
Language.ZH_CN_GUANGXI: "zh-CN-guangxi",
|
||||
Language.ZH_CN_HENAN: "zh-CN-henan",
|
||||
Language.ZH_CN_LIAONING: "zh-CN-liaoning",
|
||||
Language.ZH_CN_SHAANXI: "zh-CN-shaanxi",
|
||||
Language.ZH_CN_SHANDONG: "zh-CN-shandong",
|
||||
Language.ZH_CN_SICHUAN: "zh-CN-sichuan",
|
||||
Language.ZH_HK: "zh-HK",
|
||||
Language.ZH_TW: "zh-TW",
|
||||
# Zulu
|
||||
Language.ZU: "zu-ZA",
|
||||
Language.ZU_ZA: "zu-ZA",
|
||||
}
|
||||
return language_map.get(language)
|
||||
|
||||
|
||||
def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat:
|
||||
sample_rate_map = {
|
||||
8000: SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm,
|
||||
16000: SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm,
|
||||
22050: SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm,
|
||||
24000: SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm,
|
||||
44100: SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm,
|
||||
48000: SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm,
|
||||
}
|
||||
return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm)
|
||||
|
||||
|
||||
class AzureLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Azure OpenAI using the OpenAI-compatible interface.
|
||||
|
||||
This service extends OpenAILLMService to connect to Azure's OpenAI endpoint while
|
||||
maintaining full compatibility with OpenAI's interface and functionality.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for accessing Azure OpenAI
|
||||
endpoint (str): The Azure endpoint URL
|
||||
model (str): The model identifier to use
|
||||
api_version (str, optional): Azure API version. Defaults to "2024-09-01-preview"
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_version: str = "2024-09-01-preview",
|
||||
**kwargs,
|
||||
):
|
||||
# Initialize variables before calling parent __init__() because that
|
||||
# will call create_client() and we need those values there.
|
||||
self._endpoint = endpoint
|
||||
self._api_version = api_version
|
||||
super().__init__(api_key=api_key, model=model, **kwargs)
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
"""Create OpenAI-compatible client for Azure OpenAI endpoint."""
|
||||
logger.debug(f"Creating Azure OpenAI client with endpoint {self._endpoint}")
|
||||
return AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=self._endpoint,
|
||||
api_version=self._api_version,
|
||||
)
|
||||
|
||||
|
||||
class AzureBaseTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
emphasis: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN_US
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = "1.05"
|
||||
role: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
style_degree: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice="en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._settings = {
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-US",
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"role": params.role,
|
||||
"style": params.style,
|
||||
"style_degree": params.style_degree,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self._api_key = api_key
|
||||
self._region = region
|
||||
self._voice_id = voice
|
||||
self._speech_synthesizer = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
return language_to_azure_language(language)
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
language = self._settings["language"]
|
||||
ssml = (
|
||||
f"<speak version='1.0' xml:lang='{language}' "
|
||||
"xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{self._voice_id}'>"
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />"
|
||||
)
|
||||
|
||||
if self._settings["style"]:
|
||||
ssml += f"<mstts:express-as style='{self._settings['style']}'"
|
||||
if self._settings["style_degree"]:
|
||||
ssml += f" styledegree='{self._settings['style_degree']}'"
|
||||
if self._settings["role"]:
|
||||
ssml += f" role='{self._settings['role']}'"
|
||||
ssml += ">"
|
||||
|
||||
prosody_attrs = []
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._settings["style"]:
|
||||
ssml += "</mstts:express-as>"
|
||||
|
||||
ssml += "</voice></speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
|
||||
class AzureTTSService(AzureBaseTTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._speech_config = None
|
||||
self._speech_synthesizer = None
|
||||
self._audio_queue = asyncio.Queue()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
|
||||
if self._speech_config:
|
||||
return
|
||||
|
||||
# Now self.sample_rate is properly initialized
|
||||
self._speech_config = SpeechConfig(
|
||||
subscription=self._api_key,
|
||||
region=self._region,
|
||||
speech_recognition_language=self._settings["language"],
|
||||
)
|
||||
self._speech_config.set_speech_synthesis_output_format(
|
||||
sample_rate_to_output_format(self.sample_rate)
|
||||
)
|
||||
self._speech_config.set_service_property(
|
||||
"synthesizer.synthesis.connection.synthesisConnectionImpl",
|
||||
"websocket",
|
||||
ServicePropertyChannel.UriQueryParameter,
|
||||
)
|
||||
|
||||
self._speech_synthesizer = SpeechSynthesizer(
|
||||
speech_config=self._speech_config, audio_config=None
|
||||
)
|
||||
|
||||
# Set up event handlers
|
||||
self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing)
|
||||
self._speech_synthesizer.synthesis_completed.connect(self._handle_completed)
|
||||
self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled)
|
||||
|
||||
def _handle_synthesizing(self, evt):
|
||||
"""Handle audio chunks as they arrive"""
|
||||
if evt.result and evt.result.audio_data:
|
||||
self._audio_queue.put_nowait(evt.result.audio_data)
|
||||
|
||||
def _handle_completed(self, evt):
|
||||
"""Handle synthesis completion"""
|
||||
self._audio_queue.put_nowait(None) # Signal completion
|
||||
|
||||
def _handle_canceled(self, evt):
|
||||
"""Handle synthesis cancellation"""
|
||||
logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}")
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
async def flush_audio(self):
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
self._speech_synthesizer.speak_ssml_async(ssml)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Stream audio chunks as they arrive
|
||||
while True:
|
||||
chunk = await self._audio_queue.get()
|
||||
if chunk is None: # End of stream
|
||||
break
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(
|
||||
audio=chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error during synthesis: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._speech_config = None
|
||||
self._speech_synthesizer = None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
|
||||
if self._speech_config:
|
||||
return
|
||||
|
||||
self._speech_config = SpeechConfig(
|
||||
subscription=self._api_key,
|
||||
region=self._region,
|
||||
speech_recognition_language=self._settings["language"],
|
||||
)
|
||||
self._speech_config.set_speech_synthesis_output_format(
|
||||
sample_rate_to_output_format(self.sample_rate)
|
||||
)
|
||||
self._speech_synthesizer = SpeechSynthesizer(
|
||||
speech_config=self._speech_config, audio_config=None
|
||||
)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
|
||||
result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml)
|
||||
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
# Azure always sends a 44-byte header. Strip it off.
|
||||
yield TTSAudioRawFrame(
|
||||
audio=result.audio_data[44:],
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield TTSStoppedFrame()
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
if cancellation_details.reason == CancellationReason.Error:
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
|
||||
|
||||
class AzureSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
language: Language = Language.EN_US,
|
||||
sample_rate: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._speech_config = SpeechConfig(
|
||||
subscription=api_key,
|
||||
region=region,
|
||||
speech_recognition_language=language_to_azure_language(language),
|
||||
)
|
||||
|
||||
self._audio_stream = None
|
||||
self._speech_recognizer = None
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
await self.start_processing_metrics()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.write(audio)
|
||||
await self.stop_processing_metrics()
|
||||
yield None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
|
||||
if self._audio_stream:
|
||||
return
|
||||
|
||||
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
|
||||
self._audio_stream = PushAudioInputStream(stream_format)
|
||||
|
||||
audio_config = AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
|
||||
if self._speech_recognizer:
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
|
||||
if self._speech_recognizer:
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
|
||||
def _on_handle_recognized(self, event):
|
||||
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
|
||||
frame = TranscriptionFrame(event.result.text, "", time_now_iso8601())
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
|
||||
class AzureImageGenServiceREST(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size: str,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_version="2023-06-01-preview",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._api_key = api_key
|
||||
self._azure_endpoint = endpoint
|
||||
self._api_version = api_version
|
||||
self.set_model_name(model)
|
||||
self._image_size = image_size
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
|
||||
|
||||
headers = {"api-key": self._api_key, "Content-Type": "application/json"}
|
||||
|
||||
body = {
|
||||
# Enter your prompt text here
|
||||
"prompt": prompt,
|
||||
"size": self._image_size,
|
||||
"n": 1,
|
||||
}
|
||||
|
||||
async with self._aiohttp_session.post(url, headers=headers, json=body) as submission:
|
||||
# We never get past this line, because this header isn't
|
||||
# defined on a 429 response, but something is eating our
|
||||
# exceptions!
|
||||
operation_location = submission.headers["operation-location"]
|
||||
status = ""
|
||||
attempts_left = 120
|
||||
json_response = None
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
logger.error(f"{self} error: image generation timed out")
|
||||
yield ErrorFrame("Image generation timed out")
|
||||
return
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
response = await self._aiohttp_session.get(operation_location, headers=headers)
|
||||
|
||||
json_response = await response.json()
|
||||
status = json_response["status"]
|
||||
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
# Load the image from the url
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
frame = URLImageRawFrame(
|
||||
url=image_url, image=image.tobytes(), size=image.size, format=image.format
|
||||
)
|
||||
yield frame
|
||||
15
src/pipecat/services/azure/__init__.py
Normal file
15
src/pipecat/services/azure/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "azure", "azure.[llm,stt,tts]")
|
||||
336
src/pipecat/services/azure/common.py
Normal file
336
src/pipecat/services/azure/common.py
Normal file
@@ -0,0 +1,336 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
def language_to_azure_language(language: Language) -> Optional[str]:
|
||||
language_map = {
|
||||
# Afrikaans
|
||||
Language.AF: "af-ZA",
|
||||
Language.AF_ZA: "af-ZA",
|
||||
# Amharic
|
||||
Language.AM: "am-ET",
|
||||
Language.AM_ET: "am-ET",
|
||||
# Arabic
|
||||
Language.AR: "ar-AE", # Default to UAE Arabic
|
||||
Language.AR_AE: "ar-AE",
|
||||
Language.AR_BH: "ar-BH",
|
||||
Language.AR_DZ: "ar-DZ",
|
||||
Language.AR_EG: "ar-EG",
|
||||
Language.AR_IQ: "ar-IQ",
|
||||
Language.AR_JO: "ar-JO",
|
||||
Language.AR_KW: "ar-KW",
|
||||
Language.AR_LB: "ar-LB",
|
||||
Language.AR_LY: "ar-LY",
|
||||
Language.AR_MA: "ar-MA",
|
||||
Language.AR_OM: "ar-OM",
|
||||
Language.AR_QA: "ar-QA",
|
||||
Language.AR_SA: "ar-SA",
|
||||
Language.AR_SY: "ar-SY",
|
||||
Language.AR_TN: "ar-TN",
|
||||
Language.AR_YE: "ar-YE",
|
||||
# Assamese
|
||||
Language.AS: "as-IN",
|
||||
Language.AS_IN: "as-IN",
|
||||
# Azerbaijani
|
||||
Language.AZ: "az-AZ",
|
||||
Language.AZ_AZ: "az-AZ",
|
||||
# Bulgarian
|
||||
Language.BG: "bg-BG",
|
||||
Language.BG_BG: "bg-BG",
|
||||
# Bengali
|
||||
Language.BN: "bn-IN", # Default to Indian Bengali
|
||||
Language.BN_BD: "bn-BD",
|
||||
Language.BN_IN: "bn-IN",
|
||||
# Bosnian
|
||||
Language.BS: "bs-BA",
|
||||
Language.BS_BA: "bs-BA",
|
||||
# Catalan
|
||||
Language.CA: "ca-ES",
|
||||
Language.CA_ES: "ca-ES",
|
||||
# Czech
|
||||
Language.CS: "cs-CZ",
|
||||
Language.CS_CZ: "cs-CZ",
|
||||
# Welsh
|
||||
Language.CY: "cy-GB",
|
||||
Language.CY_GB: "cy-GB",
|
||||
# Danish
|
||||
Language.DA: "da-DK",
|
||||
Language.DA_DK: "da-DK",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_AT: "de-AT",
|
||||
Language.DE_CH: "de-CH",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Greek
|
||||
Language.EL: "el-GR",
|
||||
Language.EL_GR: "el-GR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US English
|
||||
Language.EN_AU: "en-AU",
|
||||
Language.EN_CA: "en-CA",
|
||||
Language.EN_GB: "en-GB",
|
||||
Language.EN_HK: "en-HK",
|
||||
Language.EN_IE: "en-IE",
|
||||
Language.EN_IN: "en-IN",
|
||||
Language.EN_KE: "en-KE",
|
||||
Language.EN_NG: "en-NG",
|
||||
Language.EN_NZ: "en-NZ",
|
||||
Language.EN_PH: "en-PH",
|
||||
Language.EN_SG: "en-SG",
|
||||
Language.EN_TZ: "en-TZ",
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_ZA: "en-ZA",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain Spanish
|
||||
Language.ES_AR: "es-AR",
|
||||
Language.ES_BO: "es-BO",
|
||||
Language.ES_CL: "es-CL",
|
||||
Language.ES_CO: "es-CO",
|
||||
Language.ES_CR: "es-CR",
|
||||
Language.ES_CU: "es-CU",
|
||||
Language.ES_DO: "es-DO",
|
||||
Language.ES_EC: "es-EC",
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_GQ: "es-GQ",
|
||||
Language.ES_GT: "es-GT",
|
||||
Language.ES_HN: "es-HN",
|
||||
Language.ES_MX: "es-MX",
|
||||
Language.ES_NI: "es-NI",
|
||||
Language.ES_PA: "es-PA",
|
||||
Language.ES_PE: "es-PE",
|
||||
Language.ES_PR: "es-PR",
|
||||
Language.ES_PY: "es-PY",
|
||||
Language.ES_SV: "es-SV",
|
||||
Language.ES_US: "es-US",
|
||||
Language.ES_UY: "es-UY",
|
||||
Language.ES_VE: "es-VE",
|
||||
# Estonian
|
||||
Language.ET: "et-EE",
|
||||
Language.ET_EE: "et-EE",
|
||||
# Basque
|
||||
Language.EU: "eu-ES",
|
||||
Language.EU_ES: "eu-ES",
|
||||
# Persian
|
||||
Language.FA: "fa-IR",
|
||||
Language.FA_IR: "fa-IR",
|
||||
# Finnish
|
||||
Language.FI: "fi-FI",
|
||||
Language.FI_FI: "fi-FI",
|
||||
# Filipino
|
||||
Language.FIL: "fil-PH",
|
||||
Language.FIL_PH: "fil-PH",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_BE: "fr-BE",
|
||||
Language.FR_CA: "fr-CA",
|
||||
Language.FR_CH: "fr-CH",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# Irish
|
||||
Language.GA: "ga-IE",
|
||||
Language.GA_IE: "ga-IE",
|
||||
# Galician
|
||||
Language.GL: "gl-ES",
|
||||
Language.GL_ES: "gl-ES",
|
||||
# Gujarati
|
||||
Language.GU: "gu-IN",
|
||||
Language.GU_IN: "gu-IN",
|
||||
# Hebrew
|
||||
Language.HE: "he-IL",
|
||||
Language.HE_IL: "he-IL",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Croatian
|
||||
Language.HR: "hr-HR",
|
||||
Language.HR_HR: "hr-HR",
|
||||
# Hungarian
|
||||
Language.HU: "hu-HU",
|
||||
Language.HU_HU: "hu-HU",
|
||||
# Armenian
|
||||
Language.HY: "hy-AM",
|
||||
Language.HY_AM: "hy-AM",
|
||||
# Indonesian
|
||||
Language.ID: "id-ID",
|
||||
Language.ID_ID: "id-ID",
|
||||
# Icelandic
|
||||
Language.IS: "is-IS",
|
||||
Language.IS_IS: "is-IS",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Inuktitut
|
||||
Language.IU_CANS_CA: "iu-Cans-CA",
|
||||
Language.IU_LATN_CA: "iu-Latn-CA",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Javanese
|
||||
Language.JV: "jv-ID",
|
||||
Language.JV_ID: "jv-ID",
|
||||
# Georgian
|
||||
Language.KA: "ka-GE",
|
||||
Language.KA_GE: "ka-GE",
|
||||
# Kazakh
|
||||
Language.KK: "kk-KZ",
|
||||
Language.KK_KZ: "kk-KZ",
|
||||
# Khmer
|
||||
Language.KM: "km-KH",
|
||||
Language.KM_KH: "km-KH",
|
||||
# Kannada
|
||||
Language.KN: "kn-IN",
|
||||
Language.KN_IN: "kn-IN",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Lao
|
||||
Language.LO: "lo-LA",
|
||||
Language.LO_LA: "lo-LA",
|
||||
# Lithuanian
|
||||
Language.LT: "lt-LT",
|
||||
Language.LT_LT: "lt-LT",
|
||||
# Latvian
|
||||
Language.LV: "lv-LV",
|
||||
Language.LV_LV: "lv-LV",
|
||||
# Macedonian
|
||||
Language.MK: "mk-MK",
|
||||
Language.MK_MK: "mk-MK",
|
||||
# Malayalam
|
||||
Language.ML: "ml-IN",
|
||||
Language.ML_IN: "ml-IN",
|
||||
# Mongolian
|
||||
Language.MN: "mn-MN",
|
||||
Language.MN_MN: "mn-MN",
|
||||
# Marathi
|
||||
Language.MR: "mr-IN",
|
||||
Language.MR_IN: "mr-IN",
|
||||
# Malay
|
||||
Language.MS: "ms-MY",
|
||||
Language.MS_MY: "ms-MY",
|
||||
# Maltese
|
||||
Language.MT: "mt-MT",
|
||||
Language.MT_MT: "mt-MT",
|
||||
# Burmese
|
||||
Language.MY: "my-MM",
|
||||
Language.MY_MM: "my-MM",
|
||||
# Norwegian
|
||||
Language.NB: "nb-NO",
|
||||
Language.NB_NO: "nb-NO",
|
||||
Language.NO: "nb-NO",
|
||||
# Nepali
|
||||
Language.NE: "ne-NP",
|
||||
Language.NE_NP: "ne-NP",
|
||||
# Dutch
|
||||
Language.NL: "nl-NL",
|
||||
Language.NL_BE: "nl-BE",
|
||||
Language.NL_NL: "nl-NL",
|
||||
# Odia
|
||||
Language.OR: "or-IN",
|
||||
Language.OR_IN: "or-IN",
|
||||
# Punjabi
|
||||
Language.PA: "pa-IN",
|
||||
Language.PA_IN: "pa-IN",
|
||||
# Polish
|
||||
Language.PL: "pl-PL",
|
||||
Language.PL_PL: "pl-PL",
|
||||
# Pashto
|
||||
Language.PS: "ps-AF",
|
||||
Language.PS_AF: "ps-AF",
|
||||
# Portuguese
|
||||
Language.PT: "pt-PT",
|
||||
Language.PT_BR: "pt-BR",
|
||||
Language.PT_PT: "pt-PT",
|
||||
# Romanian
|
||||
Language.RO: "ro-RO",
|
||||
Language.RO_RO: "ro-RO",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Sinhala
|
||||
Language.SI: "si-LK",
|
||||
Language.SI_LK: "si-LK",
|
||||
# Slovak
|
||||
Language.SK: "sk-SK",
|
||||
Language.SK_SK: "sk-SK",
|
||||
# Slovenian
|
||||
Language.SL: "sl-SI",
|
||||
Language.SL_SI: "sl-SI",
|
||||
# Somali
|
||||
Language.SO: "so-SO",
|
||||
Language.SO_SO: "so-SO",
|
||||
# Albanian
|
||||
Language.SQ: "sq-AL",
|
||||
Language.SQ_AL: "sq-AL",
|
||||
# Serbian
|
||||
Language.SR: "sr-RS",
|
||||
Language.SR_RS: "sr-RS",
|
||||
Language.SR_LATN: "sr-Latn-RS",
|
||||
Language.SR_LATN_RS: "sr-Latn-RS",
|
||||
# Sundanese
|
||||
Language.SU: "su-ID",
|
||||
Language.SU_ID: "su-ID",
|
||||
# Swedish
|
||||
Language.SV: "sv-SE",
|
||||
Language.SV_SE: "sv-SE",
|
||||
# Swahili
|
||||
Language.SW: "sw-KE",
|
||||
Language.SW_KE: "sw-KE",
|
||||
Language.SW_TZ: "sw-TZ",
|
||||
# Tamil
|
||||
Language.TA: "ta-IN",
|
||||
Language.TA_IN: "ta-IN",
|
||||
Language.TA_LK: "ta-LK",
|
||||
Language.TA_MY: "ta-MY",
|
||||
Language.TA_SG: "ta-SG",
|
||||
# Telugu
|
||||
Language.TE: "te-IN",
|
||||
Language.TE_IN: "te-IN",
|
||||
# Thai
|
||||
Language.TH: "th-TH",
|
||||
Language.TH_TH: "th-TH",
|
||||
# Turkish
|
||||
Language.TR: "tr-TR",
|
||||
Language.TR_TR: "tr-TR",
|
||||
# Ukrainian
|
||||
Language.UK: "uk-UA",
|
||||
Language.UK_UA: "uk-UA",
|
||||
# Urdu
|
||||
Language.UR: "ur-IN",
|
||||
Language.UR_IN: "ur-IN",
|
||||
Language.UR_PK: "ur-PK",
|
||||
# Uzbek
|
||||
Language.UZ: "uz-UZ",
|
||||
Language.UZ_UZ: "uz-UZ",
|
||||
# Vietnamese
|
||||
Language.VI: "vi-VN",
|
||||
Language.VI_VN: "vi-VN",
|
||||
# Wu Chinese
|
||||
Language.WUU: "wuu-CN",
|
||||
Language.WUU_CN: "wuu-CN",
|
||||
# Yue Chinese
|
||||
Language.YUE: "yue-CN",
|
||||
Language.YUE_CN: "yue-CN",
|
||||
# Chinese
|
||||
Language.ZH: "zh-CN",
|
||||
Language.ZH_CN: "zh-CN",
|
||||
Language.ZH_CN_GUANGXI: "zh-CN-guangxi",
|
||||
Language.ZH_CN_HENAN: "zh-CN-henan",
|
||||
Language.ZH_CN_LIAONING: "zh-CN-liaoning",
|
||||
Language.ZH_CN_SHAANXI: "zh-CN-shaanxi",
|
||||
Language.ZH_CN_SHANDONG: "zh-CN-shandong",
|
||||
Language.ZH_CN_SICHUAN: "zh-CN-sichuan",
|
||||
Language.ZH_HK: "zh-HK",
|
||||
Language.ZH_TW: "zh-TW",
|
||||
# Zulu
|
||||
Language.ZU: "zu-ZA",
|
||||
Language.ZU_ZA: "zu-ZA",
|
||||
}
|
||||
return language_map.get(language)
|
||||
86
src/pipecat/services/azure/image.py
Normal file
86
src/pipecat/services/azure/image.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
|
||||
from pipecat.services.ai_services import ImageGenService
|
||||
|
||||
|
||||
class AzureImageGenServiceREST(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size: str,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
api_version="2023-06-01-preview",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._api_key = api_key
|
||||
self._azure_endpoint = endpoint
|
||||
self._api_version = api_version
|
||||
self.set_model_name(model)
|
||||
self._image_size = image_size
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
|
||||
|
||||
headers = {"api-key": self._api_key, "Content-Type": "application/json"}
|
||||
|
||||
body = {
|
||||
# Enter your prompt text here
|
||||
"prompt": prompt,
|
||||
"size": self._image_size,
|
||||
"n": 1,
|
||||
}
|
||||
|
||||
async with self._aiohttp_session.post(url, headers=headers, json=body) as submission:
|
||||
# We never get past this line, because this header isn't
|
||||
# defined on a 429 response, but something is eating our
|
||||
# exceptions!
|
||||
operation_location = submission.headers["operation-location"]
|
||||
status = ""
|
||||
attempts_left = 120
|
||||
json_response = None
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
logger.error(f"{self} error: image generation timed out")
|
||||
yield ErrorFrame("Image generation timed out")
|
||||
return
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
response = await self._aiohttp_session.get(operation_location, headers=headers)
|
||||
|
||||
json_response = await response.json()
|
||||
status = json_response["status"]
|
||||
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
# Load the image from the url
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
frame = URLImageRawFrame(
|
||||
url=image_url, image=image.tobytes(), size=image.size, format=image.format
|
||||
)
|
||||
yield frame
|
||||
49
src/pipecat/services/azure/llm.py
Normal file
49
src/pipecat/services/azure/llm.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class AzureLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Azure OpenAI using the OpenAI-compatible interface.
|
||||
|
||||
This service extends OpenAILLMService to connect to Azure's OpenAI endpoint while
|
||||
maintaining full compatibility with OpenAI's interface and functionality.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for accessing Azure OpenAI
|
||||
endpoint (str): The Azure endpoint URL
|
||||
model (str): The model identifier to use
|
||||
api_version (str, optional): Azure API version. Defaults to "2024-09-01-preview"
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
model: str,
|
||||
api_version: str = "2024-09-01-preview",
|
||||
**kwargs,
|
||||
):
|
||||
# Initialize variables before calling parent __init__() because that
|
||||
# will call create_client() and we need those values there.
|
||||
self._endpoint = endpoint
|
||||
self._api_version = api_version
|
||||
super().__init__(api_key=api_key, model=model, **kwargs)
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
"""Create OpenAI-compatible client for Azure OpenAI endpoint."""
|
||||
logger.debug(f"Creating Azure OpenAI client with endpoint {self._endpoint}")
|
||||
return AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=self._endpoint,
|
||||
api_version=self._api_version,
|
||||
)
|
||||
107
src/pipecat/services/azure/stt.py
Normal file
107
src/pipecat/services/azure/stt.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.services.azure.common import language_to_azure_language
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
from azure.cognitiveservices.speech import (
|
||||
ResultReason,
|
||||
SpeechConfig,
|
||||
SpeechRecognizer,
|
||||
)
|
||||
from azure.cognitiveservices.speech.audio import (
|
||||
AudioStreamFormat,
|
||||
PushAudioInputStream,
|
||||
)
|
||||
from azure.cognitiveservices.speech.dialog import AudioConfig
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Azure, you need to `pip install pipecat-ai[azure]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class AzureSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
language: Language = Language.EN_US,
|
||||
sample_rate: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._speech_config = SpeechConfig(
|
||||
subscription=api_key,
|
||||
region=region,
|
||||
speech_recognition_language=language_to_azure_language(language),
|
||||
)
|
||||
|
||||
self._audio_stream = None
|
||||
self._speech_recognizer = None
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
await self.start_processing_metrics()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.write(audio)
|
||||
await self.stop_processing_metrics()
|
||||
yield None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
|
||||
if self._audio_stream:
|
||||
return
|
||||
|
||||
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
|
||||
self._audio_stream = PushAudioInputStream(stream_format)
|
||||
|
||||
audio_config = AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._speech_recognizer = SpeechRecognizer(
|
||||
speech_config=self._speech_config, audio_config=audio_config
|
||||
)
|
||||
self._speech_recognizer.recognized.connect(self._on_handle_recognized)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
|
||||
if self._speech_recognizer:
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
|
||||
if self._speech_recognizer:
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
|
||||
def _on_handle_recognized(self, event):
|
||||
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
|
||||
frame = TranscriptionFrame(event.result.text, "", time_now_iso8601())
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
290
src/pipecat/services/azure/tts.py
Normal file
290
src/pipecat/services/azure/tts.py
Normal file
@@ -0,0 +1,290 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.azure.common import language_to_azure_language
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
from azure.cognitiveservices.speech import (
|
||||
CancellationReason,
|
||||
ResultReason,
|
||||
ServicePropertyChannel,
|
||||
SpeechConfig,
|
||||
SpeechSynthesisOutputFormat,
|
||||
SpeechSynthesizer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Azure, you need to `pip install pipecat-ai[azure]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat:
|
||||
sample_rate_map = {
|
||||
8000: SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm,
|
||||
16000: SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm,
|
||||
22050: SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm,
|
||||
24000: SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm,
|
||||
44100: SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm,
|
||||
48000: SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm,
|
||||
}
|
||||
return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm)
|
||||
|
||||
|
||||
class AzureBaseTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
emphasis: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN_US
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = "1.05"
|
||||
role: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
style_degree: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice="en-US-SaraNeural",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._settings = {
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-US",
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"role": params.role,
|
||||
"style": params.style,
|
||||
"style_degree": params.style_degree,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self._api_key = api_key
|
||||
self._region = region
|
||||
self._voice_id = voice
|
||||
self._speech_synthesizer = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
return language_to_azure_language(language)
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
language = self._settings["language"]
|
||||
ssml = (
|
||||
f"<speak version='1.0' xml:lang='{language}' "
|
||||
"xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{self._voice_id}'>"
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />"
|
||||
)
|
||||
|
||||
if self._settings["style"]:
|
||||
ssml += f"<mstts:express-as style='{self._settings['style']}'"
|
||||
if self._settings["style_degree"]:
|
||||
ssml += f" styledegree='{self._settings['style_degree']}'"
|
||||
if self._settings["role"]:
|
||||
ssml += f" role='{self._settings['role']}'"
|
||||
ssml += ">"
|
||||
|
||||
prosody_attrs = []
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._settings["style"]:
|
||||
ssml += "</mstts:express-as>"
|
||||
|
||||
ssml += "</voice></speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
|
||||
class AzureTTSService(AzureBaseTTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._speech_config = None
|
||||
self._speech_synthesizer = None
|
||||
self._audio_queue = asyncio.Queue()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
|
||||
if self._speech_config:
|
||||
return
|
||||
|
||||
# Now self.sample_rate is properly initialized
|
||||
self._speech_config = SpeechConfig(
|
||||
subscription=self._api_key,
|
||||
region=self._region,
|
||||
speech_recognition_language=self._settings["language"],
|
||||
)
|
||||
self._speech_config.set_speech_synthesis_output_format(
|
||||
sample_rate_to_output_format(self.sample_rate)
|
||||
)
|
||||
self._speech_config.set_service_property(
|
||||
"synthesizer.synthesis.connection.synthesisConnectionImpl",
|
||||
"websocket",
|
||||
ServicePropertyChannel.UriQueryParameter,
|
||||
)
|
||||
|
||||
self._speech_synthesizer = SpeechSynthesizer(
|
||||
speech_config=self._speech_config, audio_config=None
|
||||
)
|
||||
|
||||
# Set up event handlers
|
||||
self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing)
|
||||
self._speech_synthesizer.synthesis_completed.connect(self._handle_completed)
|
||||
self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled)
|
||||
|
||||
def _handle_synthesizing(self, evt):
|
||||
"""Handle audio chunks as they arrive"""
|
||||
if evt.result and evt.result.audio_data:
|
||||
self._audio_queue.put_nowait(evt.result.audio_data)
|
||||
|
||||
def _handle_completed(self, evt):
|
||||
"""Handle synthesis completion"""
|
||||
self._audio_queue.put_nowait(None) # Signal completion
|
||||
|
||||
def _handle_canceled(self, evt):
|
||||
"""Handle synthesis cancellation"""
|
||||
logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}")
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
async def flush_audio(self):
|
||||
logger.trace(f"{self}: flushing audio")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
if self._speech_synthesizer is None:
|
||||
error_msg = "Speech synthesizer not initialized."
|
||||
logger.error(error_msg)
|
||||
yield ErrorFrame(error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
self._speech_synthesizer.speak_ssml_async(ssml)
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
# Stream audio chunks as they arrive
|
||||
while True:
|
||||
chunk = await self._audio_queue.get()
|
||||
if chunk is None: # End of stream
|
||||
break
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSAudioRawFrame(
|
||||
audio=chunk,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error during synthesis: {e}")
|
||||
yield TTSStoppedFrame()
|
||||
# Could add reconnection logic here if needed
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
|
||||
class AzureHttpTTSService(AzureBaseTTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._speech_config = None
|
||||
self._speech_synthesizer = None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
|
||||
if self._speech_config:
|
||||
return
|
||||
|
||||
self._speech_config = SpeechConfig(
|
||||
subscription=self._api_key,
|
||||
region=self._region,
|
||||
speech_recognition_language=self._settings["language"],
|
||||
)
|
||||
self._speech_config.set_speech_synthesis_output_format(
|
||||
sample_rate_to_output_format(self.sample_rate)
|
||||
)
|
||||
self._speech_synthesizer = SpeechSynthesizer(
|
||||
speech_config=self._speech_config, audio_config=None
|
||||
)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
ssml = self._construct_ssml(text)
|
||||
|
||||
result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml)
|
||||
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
# Azure always sends a 44-byte header. Strip it off.
|
||||
yield TTSAudioRawFrame(
|
||||
audio=result.audio_data[44:],
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield TTSStoppedFrame()
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
if cancellation_details.reason == CancellationReason.Error:
|
||||
logger.error(f"{self} error: {cancellation_details.error_details}")
|
||||
13
src/pipecat/services/canonical/__init__.py
Normal file
13
src/pipecat/services/canonical/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .metrics import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "canonical", "canonical.metrics")
|
||||
13
src/pipecat/services/cartesia/__init__.py
Normal file
13
src/pipecat/services/cartesia/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cartesia", "cartesia.tts")
|
||||
@@ -35,9 +35,7 @@ try:
|
||||
from cartesia import AsyncCartesia
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`. Also, set `CARTESIA_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
13
src/pipecat/services/cerebras/__init__.py
Normal file
13
src/pipecat/services/cerebras/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cerebras", "cerebras.llm")
|
||||
@@ -11,7 +11,7 @@ from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class CerebrasLLMService(OpenAILLMService):
|
||||
14
src/pipecat/services/deepgram/__init__.py
Normal file
14
src/pipecat/services/deepgram/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "deepgram", "deepgram.[stt,tts]")
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
@@ -12,23 +11,18 @@ from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import STTService, TTSService
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
# See .env.example for Deepgram configuration needed
|
||||
try:
|
||||
from deepgram import (
|
||||
AsyncListenWebSocketClient,
|
||||
@@ -38,80 +32,13 @@ try:
|
||||
LiveOptions,
|
||||
LiveResultResponse,
|
||||
LiveTranscriptionEvents,
|
||||
SpeakOptions,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`. Also, set `DEEPGRAM_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepgramTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice: str = "aura-helios-en",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._settings = {
|
||||
"encoding": encoding,
|
||||
}
|
||||
self.set_voice(voice)
|
||||
self._deepgram_client = DeepgramClient(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
options = SpeakOptions(
|
||||
model=self._voice_id,
|
||||
encoding=self._settings["encoding"],
|
||||
sample_rate=self.sample_rate,
|
||||
container="none",
|
||||
)
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
self._deepgram_client.speak.v("1").stream, {"text": text}, options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# The response.stream_memory is already a BytesIO object
|
||||
audio_buffer = response.stream_memory
|
||||
|
||||
if audio_buffer is None:
|
||||
raise ValueError("No audio data received from Deepgram")
|
||||
|
||||
# Read and yield the audio data in chunks
|
||||
audio_buffer.seek(0) # Ensure we're at the start of the buffer
|
||||
chunk_size = 1024 # Use a fixed buffer size
|
||||
while True:
|
||||
await self.stop_ttfb_metrics()
|
||||
chunk = audio_buffer.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
frame = TTSAudioRawFrame(audio=chunk, sample_rate=self.sample_rate, num_channels=1)
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
|
||||
|
||||
class DeepgramSTTService(STTService):
|
||||
def __init__(
|
||||
self,
|
||||
90
src/pipecat/services/deepgram/tts.py
Normal file
90
src/pipecat/services/deepgram/tts.py
Normal file
@@ -0,0 +1,90 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
try:
|
||||
from deepgram import DeepgramClient, SpeakOptions
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class DeepgramTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice: str = "aura-helios-en",
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: str = "linear16",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._settings = {
|
||||
"encoding": encoding,
|
||||
}
|
||||
self.set_voice(voice)
|
||||
self._deepgram_client = DeepgramClient(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
options = SpeakOptions(
|
||||
model=self._voice_id,
|
||||
encoding=self._settings["encoding"],
|
||||
sample_rate=self.sample_rate,
|
||||
container="none",
|
||||
)
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
self._deepgram_client.speak.v("1").stream, {"text": text}, options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# The response.stream_memory is already a BytesIO object
|
||||
audio_buffer = response.stream_memory
|
||||
|
||||
if audio_buffer is None:
|
||||
raise ValueError("No audio data received from Deepgram")
|
||||
|
||||
# Read and yield the audio data in chunks
|
||||
audio_buffer.seek(0) # Ensure we're at the start of the buffer
|
||||
chunk_size = 1024 # Use a fixed buffer size
|
||||
while True:
|
||||
await self.stop_ttfb_metrics()
|
||||
chunk = audio_buffer.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
frame = TTSAudioRawFrame(audio=chunk, sample_rate=self.sample_rate, num_channels=1)
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
yield ErrorFrame(f"Error getting audio: {str(e)}")
|
||||
13
src/pipecat/services/deepseek/__init__.py
Normal file
13
src/pipecat/services/deepseek/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "deepseek", "deepseek.llm")
|
||||
@@ -12,7 +12,7 @@ from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class DeepSeekLLMService(OpenAILLMService):
|
||||
13
src/pipecat/services/elevenlabs/__init__.py
Normal file
13
src/pipecat/services/elevenlabs/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "elevenlabs", "elevenlabs.tts")
|
||||
@@ -33,9 +33,7 @@ try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use ElevenLabs, you need to `pip install pipecat-ai[elevenlabs]`. Also, set `ELEVENLABS_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use ElevenLabs, you need to `pip install pipecat-ai[elevenlabs]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
ElevenLabsOutputFormat = Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"]
|
||||
14
src/pipecat/services/fal/__init__.py
Normal file
14
src/pipecat/services/fal/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .image import *
|
||||
from .stt import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fal", "fal.[image,stt]")
|
||||
84
src/pipecat/services/fal/image.py
Normal file
84
src/pipecat/services/fal/image.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
from typing import AsyncGenerator, Dict, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
|
||||
from pipecat.services.ai_services import ImageGenService
|
||||
|
||||
try:
|
||||
import fal_client
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Fal, you need to `pip install pipecat-ai[fal]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class FalImageGenService(ImageGenService):
|
||||
class InputParams(BaseModel):
|
||||
seed: Optional[int] = None
|
||||
num_inference_steps: int = 8
|
||||
num_images: int = 1
|
||||
image_size: Union[str, Dict[str, int]] = "square_hd"
|
||||
expand_prompt: bool = False
|
||||
enable_safety_checker: bool = True
|
||||
format: str = "png"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
params: InputParams,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
model: str = "fal-ai/fast-sdxl",
|
||||
key: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.set_model_name(model)
|
||||
self._params = params
|
||||
self._aiohttp_session = aiohttp_session
|
||||
if key:
|
||||
os.environ["FAL_KEY"] = key
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
def load_image_bytes(encoded_image: bytes):
|
||||
buffer = io.BytesIO(encoded_image)
|
||||
image = Image.open(buffer)
|
||||
return (image.tobytes(), image.size, image.format)
|
||||
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
|
||||
response = await fal_client.run_async(
|
||||
self.model_name,
|
||||
arguments={"prompt": prompt, **self._params.model_dump(exclude_none=True)},
|
||||
)
|
||||
|
||||
image_url = response["images"][0]["url"] if response else None
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
logger.debug(f"Image generated at: {image_url}")
|
||||
|
||||
# Load the image from the url
|
||||
logger.debug(f"Downloading image {image_url} ...")
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
logger.debug(f"Downloaded image {image_url}")
|
||||
encoded_image = await response.content.read()
|
||||
(image_bytes, size, format) = await asyncio.to_thread(load_image_bytes, encoded_image)
|
||||
|
||||
frame = URLImageRawFrame(url=image_url, image=image_bytes, size=size, format=format)
|
||||
yield frame
|
||||
@@ -4,19 +4,14 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
from typing import AsyncGenerator, Dict, Optional, Union
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame, URLImageRawFrame
|
||||
from pipecat.services.ai_services import ImageGenService, SegmentedSTTService
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
|
||||
from pipecat.services.ai_services import SegmentedSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
@@ -144,65 +139,6 @@ def language_to_fal_language(language: Language) -> Optional[str]:
|
||||
return result
|
||||
|
||||
|
||||
class FalImageGenService(ImageGenService):
|
||||
class InputParams(BaseModel):
|
||||
seed: Optional[int] = None
|
||||
num_inference_steps: int = 8
|
||||
num_images: int = 1
|
||||
image_size: Union[str, Dict[str, int]] = "square_hd"
|
||||
expand_prompt: bool = False
|
||||
enable_safety_checker: bool = True
|
||||
format: str = "png"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
params: InputParams,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
model: str = "fal-ai/fast-sdxl",
|
||||
key: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.set_model_name(model)
|
||||
self._params = params
|
||||
self._aiohttp_session = aiohttp_session
|
||||
if key:
|
||||
os.environ["FAL_KEY"] = key
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
def load_image_bytes(encoded_image: bytes):
|
||||
buffer = io.BytesIO(encoded_image)
|
||||
image = Image.open(buffer)
|
||||
return (image.tobytes(), image.size, image.format)
|
||||
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
|
||||
response = await fal_client.run_async(
|
||||
self.model_name,
|
||||
arguments={"prompt": prompt, **self._params.model_dump(exclude_none=True)},
|
||||
)
|
||||
|
||||
image_url = response["images"][0]["url"] if response else None
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
logger.debug(f"Image generated at: {image_url}")
|
||||
|
||||
# Load the image from the url
|
||||
logger.debug(f"Downloading image {image_url} ...")
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
logger.debug(f"Downloaded image {image_url}")
|
||||
encoded_image = await response.content.read()
|
||||
(image_bytes, size, format) = await asyncio.to_thread(load_image_bytes, encoded_image)
|
||||
|
||||
frame = URLImageRawFrame(url=image_url, image=image_bytes, size=size, format=format)
|
||||
yield frame
|
||||
|
||||
|
||||
class FalSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using Fal's Wizper API.
|
||||
|
||||
13
src/pipecat/services/fireworks/__init__.py
Normal file
13
src/pipecat/services/fireworks/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fireworks", "fireworks.llm")
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class FireworksLLMService(OpenAILLMService):
|
||||
13
src/pipecat/services/fish/__init__.py
Normal file
13
src/pipecat/services/fish/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fish", "fish.tts")
|
||||
@@ -30,9 +30,7 @@ try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`. Also, set `FISH_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
# FishAudio supports various output formats
|
||||
@@ -51,7 +51,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.openai import (
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,22 @@
|
||||
from .frames import LLMSearchResponseFrame
|
||||
from .google import *
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .frames import *
|
||||
from .image import *
|
||||
from .llm import *
|
||||
from .llm_openai import *
|
||||
from .llm_vertex import *
|
||||
from .rtvi import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(
|
||||
globals(), "google", "google.[frames,image,llm,llm_openai,llm_vertex,rtvi,stt,tts]"
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
95
src/pipecat/services/google/image.py
Normal file
95
src/pipecat/services/google/image.py
Normal file
@@ -0,0 +1,95 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import io
|
||||
import os
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
|
||||
from pipecat.services.ai_services import ImageGenService
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GoogleImageGenService(ImageGenService):
|
||||
class InputParams(BaseModel):
|
||||
number_of_images: int = Field(default=1, ge=1, le=8)
|
||||
model: str = Field(default="imagen-3.0-generate-002")
|
||||
negative_prompt: str = Field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
params: InputParams = InputParams(),
|
||||
api_key: str,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.set_model_name(params.model)
|
||||
self._params = params
|
||||
self._client = genai.Client(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate images from a text prompt using Google's Imagen model.
|
||||
|
||||
Args:
|
||||
prompt (str): The text description to generate images from.
|
||||
|
||||
Yields:
|
||||
Frame: Generated image frames or error frames.
|
||||
"""
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_images(
|
||||
model=self._params.model,
|
||||
prompt=prompt,
|
||||
config=types.GenerateImagesConfig(
|
||||
number_of_images=self._params.number_of_images,
|
||||
negative_prompt=self._params.negative_prompt,
|
||||
),
|
||||
)
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not response or not response.generated_images:
|
||||
logger.error(f"{self} error: image generation failed")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
for img_response in response.generated_images:
|
||||
# Google returns the image data directly
|
||||
image_bytes = img_response.image.image_bytes
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
frame = URLImageRawFrame(
|
||||
url=None, # Google doesn't provide URLs, only image data
|
||||
image=image.tobytes(),
|
||||
size=image.size,
|
||||
format=image.format,
|
||||
)
|
||||
yield frame
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error generating image: {e}")
|
||||
yield ErrorFrame(f"Image generation error: {str(e)}")
|
||||
717
src/pipecat/services/google/llm.py
Normal file
717
src/pipecat/services/google/llm.py
Normal file
@@ -0,0 +1,717 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from google.api_core.exceptions import DeadlineExceeded
|
||||
|
||||
from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.google.frames import LLMSearchResponseFrame
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
try:
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as gai
|
||||
from google.generativeai.types import GenerationConfig
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GoogleUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def push_aggregation(self):
|
||||
if len(self._aggregation) > 0:
|
||||
self._context.add_message(
|
||||
glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
|
||||
)
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the tasks gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
|
||||
# Push context frame
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
self.reset()
|
||||
|
||||
|
||||
class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
async def handle_aggregation(self, aggregation: str):
|
||||
self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)]))
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
self._context.add_message(
|
||||
glm.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
glm.Part(
|
||||
function_call=glm.FunctionCall(
|
||||
id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
self._context.add_message(
|
||||
glm.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
id=frame.tool_call_id,
|
||||
name=frame.function_name,
|
||||
response={"response": "IN_PROGRESS"},
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
if frame.result:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, frame.result
|
||||
)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if message.role == "user":
|
||||
for part in message.parts:
|
||||
if part.function_response and part.function_response.id == tool_call_id:
|
||||
part.function_response.response = {"value": json.dumps(result)}
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleContextAggregatorPair:
|
||||
_user: GoogleUserContextAggregator
|
||||
_assistant: GoogleAssistantContextAggregator
|
||||
|
||||
def user(self) -> GoogleUserContextAggregator:
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> GoogleAssistantContextAggregator:
|
||||
return self._assistant
|
||||
|
||||
|
||||
class GoogleLLMContext(OpenAILLMContext):
|
||||
def __init__(
|
||||
self,
|
||||
messages: Optional[List[dict]] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Optional[dict] = None,
|
||||
):
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self.system_message = None
|
||||
|
||||
@staticmethod
|
||||
def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
|
||||
if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext):
|
||||
logger.debug(f"Upgrading to Google: {obj}")
|
||||
obj.__class__ = GoogleLLMContext
|
||||
obj._restructure_from_openai_messages()
|
||||
return obj
|
||||
|
||||
def set_messages(self, messages: List):
|
||||
self._messages[:] = messages
|
||||
self._restructure_from_openai_messages()
|
||||
|
||||
def add_messages(self, messages: List):
|
||||
# Convert each message individually
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, glm.Content):
|
||||
# Already in Gemini format
|
||||
converted_messages.append(msg)
|
||||
else:
|
||||
# Convert from standard format to Gemini format
|
||||
converted = self.from_standard_message(msg)
|
||||
if converted is not None:
|
||||
converted_messages.append(converted)
|
||||
|
||||
# Add the converted messages to our existing messages
|
||||
self._messages.extend(converted_messages)
|
||||
|
||||
def get_messages_for_logging(self):
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
obj = glm.Content.to_dict(message)
|
||||
try:
|
||||
if "parts" in obj:
|
||||
for part in obj["parts"]:
|
||||
if "inline_data" in part:
|
||||
part["inline_data"]["data"] = "..."
|
||||
except Exception as e:
|
||||
logger.debug(f"Error: {e}")
|
||||
msgs.append(obj)
|
||||
return msgs
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
|
||||
parts = []
|
||||
if text:
|
||||
parts.append(glm.Part(text=text))
|
||||
parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())))
|
||||
|
||||
self.add_message(glm.Content(role="user", parts=parts))
|
||||
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
|
||||
):
|
||||
if not audio_frames:
|
||||
return
|
||||
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
num_channels = audio_frames[0].num_channels
|
||||
|
||||
parts = []
|
||||
data = b"".join(frame.audio for frame in audio_frames)
|
||||
# NOTE(aleix): According to the docs only text or inline_data should be needed.
|
||||
# (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
|
||||
parts.append(glm.Part(text=text))
|
||||
parts.append(
|
||||
glm.Part(
|
||||
inline_data=glm.Blob(
|
||||
mime_type="audio/wav",
|
||||
data=(
|
||||
bytes(
|
||||
self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data
|
||||
)
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
self.add_message(glm.Content(role="user", parts=parts))
|
||||
# message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
|
||||
# self.add_message(message)
|
||||
|
||||
def from_standard_message(self, message):
|
||||
"""Convert standard format message to Google Content object.
|
||||
|
||||
Handles conversion of text, images, and function calls to Google's format.
|
||||
System messages are stored separately and return None.
|
||||
|
||||
Args:
|
||||
message: Message in standard format:
|
||||
{
|
||||
"role": "user/assistant/system/tool",
|
||||
"content": str | [{"type": "text/image_url", ...}] | None,
|
||||
"tool_calls": [{"function": {"name": str, "arguments": str}}]
|
||||
}
|
||||
|
||||
Returns:
|
||||
glm.Content object with:
|
||||
- role: "user" or "model" (converted from "assistant")
|
||||
- parts: List[Part] containing text, inline_data, or function calls
|
||||
Returns None for system messages.
|
||||
"""
|
||||
role = message["role"]
|
||||
content = message.get("content", [])
|
||||
if role == "system":
|
||||
self.system_message = content
|
||||
return None
|
||||
elif role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
parts.append(
|
||||
glm.Part(
|
||||
function_call=glm.FunctionCall(
|
||||
name=tc["function"]["name"],
|
||||
args=json.loads(tc["function"]["arguments"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
parts.append(
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name="tool_call_result", # seems to work to hard-code the same name every time
|
||||
response=json.loads(message["content"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(content, str):
|
||||
parts.append(glm.Part(text=content))
|
||||
elif isinstance(content, list):
|
||||
for c in content:
|
||||
if c["type"] == "text":
|
||||
parts.append(glm.Part(text=c["text"]))
|
||||
elif c["type"] == "image_url":
|
||||
parts.append(
|
||||
glm.Part(
|
||||
inline_data=glm.Blob(
|
||||
mime_type="image/jpeg",
|
||||
data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
message = glm.Content(role=role, parts=parts)
|
||||
return message
|
||||
|
||||
def to_standard_messages(self, obj) -> list:
|
||||
"""Convert Google Content object to standard structured format.
|
||||
|
||||
Handles text, images, and function calls from Google's Content/Part objects.
|
||||
|
||||
Args:
|
||||
obj: Google Content object with:
|
||||
- role: "model" (converted to "assistant") or "user"
|
||||
- parts: List[Part] containing text, inline_data, or function calls
|
||||
|
||||
Returns:
|
||||
List of messages in standard format:
|
||||
[
|
||||
{
|
||||
"role": "user/assistant/tool",
|
||||
"content": [
|
||||
{"type": "text", "text": str} |
|
||||
{"type": "image_url", "image_url": {"url": str}}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
msg = {"role": obj.role, "content": []}
|
||||
if msg["role"] == "model":
|
||||
msg["role"] = "assistant"
|
||||
|
||||
for part in obj.parts:
|
||||
if part.text:
|
||||
msg["content"].append({"type": "text", "text": part.text})
|
||||
elif part.inline_data:
|
||||
encoded = base64.b64encode(part.inline_data.data).decode("utf-8")
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"},
|
||||
}
|
||||
)
|
||||
elif part.function_call:
|
||||
args = type(part.function_call).to_dict(part.function_call).get("args", {})
|
||||
msg["tool_calls"] = [
|
||||
{
|
||||
"id": part.function_call.name,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": part.function_call.name,
|
||||
"arguments": json.dumps(args),
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
elif part.function_response:
|
||||
msg["role"] = "tool"
|
||||
resp = (
|
||||
type(part.function_response).to_dict(part.function_response).get("response", {})
|
||||
)
|
||||
msg["tool_call_id"] = part.function_response.name
|
||||
msg["content"] = json.dumps(resp)
|
||||
|
||||
# there might be no content parts for tool_calls messages
|
||||
if not msg["content"]:
|
||||
del msg["content"]
|
||||
return [msg]
|
||||
|
||||
def _restructure_from_openai_messages(self):
|
||||
"""Restructures messages to ensure proper Google format and message ordering.
|
||||
|
||||
This method handles conversion of OpenAI-formatted messages to Google format,
|
||||
with special handling for function calls, function responses, and system messages.
|
||||
System messages are added back to the context as user messages when needed.
|
||||
|
||||
The final message order is preserved as:
|
||||
1. Function calls (from model)
|
||||
2. Function responses (from user)
|
||||
3. Text messages (converted from system messages)
|
||||
|
||||
Note:
|
||||
System messages are only added back when there are no regular text
|
||||
messages in the context, ensuring proper conversation continuity
|
||||
after function calls.
|
||||
"""
|
||||
self.system_message = None
|
||||
converted_messages = []
|
||||
|
||||
# Process each message, preserving Google-formatted messages and converting others
|
||||
for message in self._messages:
|
||||
if isinstance(message, glm.Content):
|
||||
# Keep existing Google-formatted messages (e.g., function calls/responses)
|
||||
converted_messages.append(message)
|
||||
continue
|
||||
|
||||
# Convert OpenAI format to Google format, system messages return None
|
||||
converted = self.from_standard_message(message)
|
||||
if converted is not None:
|
||||
converted_messages.append(converted)
|
||||
|
||||
# Update message list
|
||||
self._messages[:] = converted_messages
|
||||
|
||||
# Check if we only have function-related messages (no regular text)
|
||||
has_regular_messages = any(
|
||||
len(msg.parts) == 1
|
||||
and not getattr(msg.parts[0], "text", None)
|
||||
and getattr(msg.parts[0], "function_call", None)
|
||||
and getattr(msg.parts[0], "function_response", None)
|
||||
for msg in self._messages
|
||||
)
|
||||
|
||||
# Add system message back as a user message if we only have function messages
|
||||
if self.system_message and not has_regular_messages:
|
||||
self._messages.append(
|
||||
glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
|
||||
)
|
||||
|
||||
# Remove any empty messages
|
||||
self._messages = [m for m in self._messages if m.parts]
|
||||
|
||||
|
||||
class GoogleLLMService(LLMService):
|
||||
"""This class implements inference with Google's AI models.
|
||||
|
||||
This service translates internally from OpenAILLMContext to the messages format
|
||||
expected by the Google AI model. We are using the OpenAILLMContext as a lingua
|
||||
franca for all LLM services, so that it is easy to switch between different LLMs.
|
||||
"""
|
||||
|
||||
# Overriding the default adapter to use the Gemini one.
|
||||
adapter_class = GeminiLLMAdapter
|
||||
|
||||
class InputParams(BaseModel):
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "gemini-2.0-flash-001",
|
||||
params: InputParams = InputParams(),
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
gai.configure(api_key=api_key)
|
||||
self.set_model_name(model)
|
||||
self._system_instruction = system_instruction
|
||||
self._create_client()
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self._tools = tools
|
||||
self._tool_config = tool_config
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def _create_client(self):
|
||||
self._client = gai.GenerativeModel(
|
||||
self._model_name, system_instruction=self._system_instruction
|
||||
)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
total_tokens = 0
|
||||
|
||||
grounding_metadata = None
|
||||
search_result = ""
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
# f"{self}: Generating chat [{self._system_instruction}] | [{context.get_messages_for_logging()}]"
|
||||
f"{self}: Generating chat [{context.get_messages_for_logging()}]"
|
||||
)
|
||||
|
||||
messages = context.messages
|
||||
if context.system_message and self._system_instruction != context.system_message:
|
||||
logger.debug(f"System instruction changed: {context.system_message}")
|
||||
self._system_instruction = context.system_message
|
||||
self._create_client()
|
||||
|
||||
# Filter out None values and create GenerationConfig
|
||||
generation_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"max_output_tokens": self._settings["max_tokens"],
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
generation_config = GenerationConfig(**generation_params) if generation_params else None
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
tools = []
|
||||
if context.tools:
|
||||
tools = context.tools
|
||||
elif self._tools:
|
||||
tools = self._tools
|
||||
tool_config = None
|
||||
if self._tool_config:
|
||||
tool_config = self._tool_config
|
||||
response = await self._client.generate_content_async(
|
||||
contents=messages,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
generation_config=generation_config,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if response.usage_metadata:
|
||||
# Use only the prompt token count from the response object
|
||||
prompt_tokens = response.usage_metadata.prompt_token_count
|
||||
total_tokens = prompt_tokens
|
||||
|
||||
async for chunk in response:
|
||||
if chunk.usage_metadata:
|
||||
# Use only the completion_tokens from the chunks. Prompt tokens are already counted and
|
||||
# are repeated here.
|
||||
completion_tokens += chunk.usage_metadata.candidates_token_count
|
||||
total_tokens += chunk.usage_metadata.candidates_token_count
|
||||
try:
|
||||
for c in chunk.parts:
|
||||
if c.text:
|
||||
search_result += c.text
|
||||
await self.push_frame(LLMTextFrame(c.text))
|
||||
elif c.function_call:
|
||||
logger.debug(f"Function call: {c.function_call}")
|
||||
args = type(c.function_call).to_dict(c.function_call).get("args", {})
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
function_name=c.function_call.name,
|
||||
arguments=args,
|
||||
)
|
||||
# Handle grounding metadata
|
||||
# It seems only the last chunk that we receive may contain this information
|
||||
# If the response doesn't include groundingMetadata, this means the response wasn't grounded.
|
||||
if chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
# logger.debug(f"candidate received: {candidate}")
|
||||
# Extract grounding metadata
|
||||
grounding_metadata = (
|
||||
{
|
||||
"rendered_content": getattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"search_entry_point",
|
||||
None,
|
||||
).rendered_content
|
||||
if hasattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"search_entry_point",
|
||||
)
|
||||
else None,
|
||||
"origins": [
|
||||
{
|
||||
"site_uri": getattr(grounding_chunk.web, "uri", None),
|
||||
"site_title": getattr(
|
||||
grounding_chunk.web, "title", None
|
||||
),
|
||||
"results": [
|
||||
{
|
||||
"text": getattr(
|
||||
grounding_support.segment, "text", ""
|
||||
),
|
||||
"confidence": getattr(
|
||||
grounding_support, "confidence_scores", None
|
||||
),
|
||||
}
|
||||
for grounding_support in getattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"grounding_supports",
|
||||
[],
|
||||
)
|
||||
if index
|
||||
in getattr(
|
||||
grounding_support, "grounding_chunk_indices", []
|
||||
)
|
||||
],
|
||||
}
|
||||
for index, grounding_chunk in enumerate(
|
||||
getattr(
|
||||
getattr(candidate, "grounding_metadata", None),
|
||||
"grounding_chunks",
|
||||
[],
|
||||
)
|
||||
)
|
||||
],
|
||||
}
|
||||
if getattr(candidate, "grounding_metadata", None)
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
# Google LLMs seem to flag safety issues a lot!
|
||||
if chunk.candidates[0].finish_reason == 3:
|
||||
logger.debug(
|
||||
f"LLM refused to generate content for safety reasons - {messages}."
|
||||
)
|
||||
else:
|
||||
logger.exception(f"{self} error: {e}")
|
||||
|
||||
except DeadlineExceeded:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
if grounding_metadata is not None and isinstance(grounding_metadata, dict):
|
||||
llm_search_frame = LLMSearchResponseFrame(
|
||||
search_result=search_result,
|
||||
origins=grounding_metadata["origins"],
|
||||
rendered_content=grounding_metadata["rendered_content"],
|
||||
)
|
||||
await self.push_frame(llm_search_frame)
|
||||
|
||||
await self.start_llm_usage_metrics(
|
||||
LLMTokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
)
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = GoogleLLMContext.upgrade_to_google(frame.context)
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = GoogleLLMContext(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = GoogleLLMContext()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
) -> GoogleContextAggregatorPair:
|
||||
"""Create an instance of GoogleContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
assistant aggregators can be provided.
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
|
||||
Returns:
|
||||
GoogleContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
GoogleContextAggregatorPair.
|
||||
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
context = GoogleLLMContext.upgrade_to_google(context)
|
||||
user = GoogleUserContextAggregator(context, **user_kwargs)
|
||||
assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs)
|
||||
return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
136
src/pipecat/services/google/llm_openai.py
Normal file
136
src/pipecat/services/google/llm_openai.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import LLMTextFrame
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import OpenAIUnhandledFunctionException
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class GoogleLLMOpenAIBetaService(OpenAILLMService):
|
||||
"""This class implements inference with Google's AI LLM models using the OpenAI format.
|
||||
Ref - https://ai.google.dev/gemini-api/docs/openai
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
model: str = "gemini-2.0-flash",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
|
||||
context
|
||||
)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
# We're streaming the LLM response to enable the fastest response times.
|
||||
# For text, we just yield each chunk as we receive it and count on consumers
|
||||
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
|
||||
#
|
||||
# If the LLM is a function call, we'll do some coalescing here.
|
||||
# If the response contains a function name, we'll yield a frame to tell consumers
|
||||
# that they can start preparing to call the function with that name.
|
||||
# We accumulate all the arguments for the rest of the streamed response, then when
|
||||
# the response is done, we package up all the arguments and the function name and
|
||||
# yield a frame containing the function name and the arguments.
|
||||
logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}")
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
logger.debug(
|
||||
f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
|
||||
)
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||
):
|
||||
if function_name == "":
|
||||
# TODO: Remove the _process_context method once Google resolves the bug
|
||||
# where the index is incorrectly set to None instead of returning the actual index,
|
||||
# which currently results in an empty function name('').
|
||||
continue
|
||||
if self.has_function(function_name):
|
||||
run_llm = False
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
107
src/pipecat/services/google/llm_vertex.py
Normal file
107
src/pipecat/services/google/llm_vertex.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_APPLICATION_CREDENTIALS` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GoogleVertexLLMService(OpenAILLMService):
|
||||
"""Implements inference with Google's AI models via Vertex AI while
|
||||
maintaining OpenAI API compatibility.
|
||||
|
||||
Reference:
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library
|
||||
|
||||
"""
|
||||
|
||||
class InputParams(OpenAILLMService.InputParams):
|
||||
"""Input parameters specific to Vertex AI."""
|
||||
|
||||
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations
|
||||
location: str = "us-east4"
|
||||
project_id: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
model: str = "google/gemini-2.0-flash-001",
|
||||
params: InputParams = OpenAILLMService.InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes the VertexLLMService.
|
||||
|
||||
Args:
|
||||
credentials (Optional[str]): JSON string of service account credentials.
|
||||
credentials_path (Optional[str]): Path to the service account JSON file.
|
||||
model (str): Model identifier. Defaults to "google/gemini-2.0-flash-001".
|
||||
params (InputParams): Vertex AI input parameters.
|
||||
**kwargs: Additional arguments for OpenAILLMService.
|
||||
"""
|
||||
base_url = self._get_base_url(params)
|
||||
self._api_key = self._get_api_token(credentials, credentials_path)
|
||||
|
||||
super().__init__(api_key=self._api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_base_url(params: InputParams) -> str:
|
||||
"""Constructs the base URL for Vertex AI API."""
|
||||
return (
|
||||
f"https://{params.location}-aiplatform.googleapis.com/v1/"
|
||||
f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
|
||||
"""Retrieves an authentication token using Google service account credentials.
|
||||
|
||||
Args:
|
||||
credentials (Optional[str]): JSON string of service account credentials.
|
||||
credentials_path (Optional[str]): Path to the service account JSON file.
|
||||
|
||||
Returns:
|
||||
str: OAuth token for API authentication.
|
||||
"""
|
||||
creds: Optional[service_account.Credentials] = None
|
||||
|
||||
if credentials:
|
||||
# Parse and load credentials from JSON string
|
||||
creds = service_account.Credentials.from_service_account_info(
|
||||
json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
elif credentials_path:
|
||||
# Load credentials from JSON file
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
|
||||
if not creds:
|
||||
raise ValueError("No valid credentials provided.")
|
||||
|
||||
creds.refresh(Request()) # Ensure token is up-to-date, lifetime is 1 hour.
|
||||
|
||||
return creds.token
|
||||
806
src/pipecat/services/google/stt.py
Normal file
806
src/pipecat/services/google/stt.py
Normal file
@@ -0,0 +1,806 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
from google.api_core.client_options import ClientOptions
|
||||
from google.cloud import speech_v2
|
||||
from google.cloud.speech_v2.types import cloud_speech
|
||||
from google.oauth2 import service_account
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_APPLICATION_CREDENTIALS` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_google_stt_language(language: Language) -> Optional[str]:
|
||||
"""Maps Language enum to Google Speech-to-Text V2 language codes.
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Optional[str]: Google STT language code or None if not supported.
|
||||
"""
|
||||
language_map = {
|
||||
# Afrikaans
|
||||
Language.AF: "af-ZA",
|
||||
Language.AF_ZA: "af-ZA",
|
||||
# Albanian
|
||||
Language.SQ: "sq-AL",
|
||||
Language.SQ_AL: "sq-AL",
|
||||
# Amharic
|
||||
Language.AM: "am-ET",
|
||||
Language.AM_ET: "am-ET",
|
||||
# Arabic
|
||||
Language.AR: "ar-EG", # Default to Egypt
|
||||
Language.AR_AE: "ar-AE",
|
||||
Language.AR_BH: "ar-BH",
|
||||
Language.AR_DZ: "ar-DZ",
|
||||
Language.AR_EG: "ar-EG",
|
||||
Language.AR_IQ: "ar-IQ",
|
||||
Language.AR_JO: "ar-JO",
|
||||
Language.AR_KW: "ar-KW",
|
||||
Language.AR_LB: "ar-LB",
|
||||
Language.AR_MA: "ar-MA",
|
||||
Language.AR_OM: "ar-OM",
|
||||
Language.AR_QA: "ar-QA",
|
||||
Language.AR_SA: "ar-SA",
|
||||
Language.AR_SY: "ar-SY",
|
||||
Language.AR_TN: "ar-TN",
|
||||
Language.AR_YE: "ar-YE",
|
||||
# Armenian
|
||||
Language.HY: "hy-AM",
|
||||
Language.HY_AM: "hy-AM",
|
||||
# Azerbaijani
|
||||
Language.AZ: "az-AZ",
|
||||
Language.AZ_AZ: "az-AZ",
|
||||
# Basque
|
||||
Language.EU: "eu-ES",
|
||||
Language.EU_ES: "eu-ES",
|
||||
# Bengali
|
||||
Language.BN: "bn-IN", # Default to India
|
||||
Language.BN_BD: "bn-BD",
|
||||
Language.BN_IN: "bn-IN",
|
||||
# Bosnian
|
||||
Language.BS: "bs-BA",
|
||||
Language.BS_BA: "bs-BA",
|
||||
# Bulgarian
|
||||
Language.BG: "bg-BG",
|
||||
Language.BG_BG: "bg-BG",
|
||||
# Burmese
|
||||
Language.MY: "my-MM",
|
||||
Language.MY_MM: "my-MM",
|
||||
# Catalan
|
||||
Language.CA: "ca-ES",
|
||||
Language.CA_ES: "ca-ES",
|
||||
# Chinese
|
||||
Language.ZH: "cmn-Hans-CN", # Default to Simplified Chinese
|
||||
Language.ZH_CN: "cmn-Hans-CN",
|
||||
Language.ZH_HK: "cmn-Hans-HK",
|
||||
Language.ZH_TW: "cmn-Hant-TW",
|
||||
Language.YUE: "yue-Hant-HK", # Cantonese
|
||||
Language.YUE_CN: "yue-Hant-HK",
|
||||
# Croatian
|
||||
Language.HR: "hr-HR",
|
||||
Language.HR_HR: "hr-HR",
|
||||
# Czech
|
||||
Language.CS: "cs-CZ",
|
||||
Language.CS_CZ: "cs-CZ",
|
||||
# Danish
|
||||
Language.DA: "da-DK",
|
||||
Language.DA_DK: "da-DK",
|
||||
# Dutch
|
||||
Language.NL: "nl-NL", # Default to Netherlands
|
||||
Language.NL_BE: "nl-BE",
|
||||
Language.NL_NL: "nl-NL",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US
|
||||
Language.EN_AU: "en-AU",
|
||||
Language.EN_CA: "en-CA",
|
||||
Language.EN_GB: "en-GB",
|
||||
Language.EN_GH: "en-GH",
|
||||
Language.EN_HK: "en-HK",
|
||||
Language.EN_IN: "en-IN",
|
||||
Language.EN_IE: "en-IE",
|
||||
Language.EN_KE: "en-KE",
|
||||
Language.EN_NG: "en-NG",
|
||||
Language.EN_NZ: "en-NZ",
|
||||
Language.EN_PH: "en-PH",
|
||||
Language.EN_SG: "en-SG",
|
||||
Language.EN_TZ: "en-TZ",
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_ZA: "en-ZA",
|
||||
# Estonian
|
||||
Language.ET: "et-EE",
|
||||
Language.ET_EE: "et-EE",
|
||||
# Filipino
|
||||
Language.FIL: "fil-PH",
|
||||
Language.FIL_PH: "fil-PH",
|
||||
# Finnish
|
||||
Language.FI: "fi-FI",
|
||||
Language.FI_FI: "fi-FI",
|
||||
# French
|
||||
Language.FR: "fr-FR", # Default to France
|
||||
Language.FR_BE: "fr-BE",
|
||||
Language.FR_CA: "fr-CA",
|
||||
Language.FR_CH: "fr-CH",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# Galician
|
||||
Language.GL: "gl-ES",
|
||||
Language.GL_ES: "gl-ES",
|
||||
# Georgian
|
||||
Language.KA: "ka-GE",
|
||||
Language.KA_GE: "ka-GE",
|
||||
# German
|
||||
Language.DE: "de-DE", # Default to Germany
|
||||
Language.DE_AT: "de-AT",
|
||||
Language.DE_CH: "de-CH",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Greek
|
||||
Language.EL: "el-GR",
|
||||
Language.EL_GR: "el-GR",
|
||||
# Gujarati
|
||||
Language.GU: "gu-IN",
|
||||
Language.GU_IN: "gu-IN",
|
||||
# Hebrew
|
||||
Language.HE: "iw-IL",
|
||||
Language.HE_IL: "iw-IL",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Hungarian
|
||||
Language.HU: "hu-HU",
|
||||
Language.HU_HU: "hu-HU",
|
||||
# Icelandic
|
||||
Language.IS: "is-IS",
|
||||
Language.IS_IS: "is-IS",
|
||||
# Indonesian
|
||||
Language.ID: "id-ID",
|
||||
Language.ID_ID: "id-ID",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
Language.IT_CH: "it-CH",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Javanese
|
||||
Language.JV: "jv-ID",
|
||||
Language.JV_ID: "jv-ID",
|
||||
# Kannada
|
||||
Language.KN: "kn-IN",
|
||||
Language.KN_IN: "kn-IN",
|
||||
# Kazakh
|
||||
Language.KK: "kk-KZ",
|
||||
Language.KK_KZ: "kk-KZ",
|
||||
# Khmer
|
||||
Language.KM: "km-KH",
|
||||
Language.KM_KH: "km-KH",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Lao
|
||||
Language.LO: "lo-LA",
|
||||
Language.LO_LA: "lo-LA",
|
||||
# Latvian
|
||||
Language.LV: "lv-LV",
|
||||
Language.LV_LV: "lv-LV",
|
||||
# Lithuanian
|
||||
Language.LT: "lt-LT",
|
||||
Language.LT_LT: "lt-LT",
|
||||
# Macedonian
|
||||
Language.MK: "mk-MK",
|
||||
Language.MK_MK: "mk-MK",
|
||||
# Malay
|
||||
Language.MS: "ms-MY",
|
||||
Language.MS_MY: "ms-MY",
|
||||
# Malayalam
|
||||
Language.ML: "ml-IN",
|
||||
Language.ML_IN: "ml-IN",
|
||||
# Marathi
|
||||
Language.MR: "mr-IN",
|
||||
Language.MR_IN: "mr-IN",
|
||||
# Mongolian
|
||||
Language.MN: "mn-MN",
|
||||
Language.MN_MN: "mn-MN",
|
||||
# Nepali
|
||||
Language.NE: "ne-NP",
|
||||
Language.NE_NP: "ne-NP",
|
||||
# Norwegian
|
||||
Language.NO: "no-NO",
|
||||
Language.NB: "no-NO",
|
||||
Language.NB_NO: "no-NO",
|
||||
# Persian
|
||||
Language.FA: "fa-IR",
|
||||
Language.FA_IR: "fa-IR",
|
||||
# Polish
|
||||
Language.PL: "pl-PL",
|
||||
Language.PL_PL: "pl-PL",
|
||||
# Portuguese
|
||||
Language.PT: "pt-PT", # Default to Portugal
|
||||
Language.PT_BR: "pt-BR",
|
||||
Language.PT_PT: "pt-PT",
|
||||
# Punjabi
|
||||
Language.PA: "pa-Guru-IN",
|
||||
Language.PA_IN: "pa-Guru-IN",
|
||||
# Romanian
|
||||
Language.RO: "ro-RO",
|
||||
Language.RO_RO: "ro-RO",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Serbian
|
||||
Language.SR: "sr-RS",
|
||||
Language.SR_RS: "sr-RS",
|
||||
# Sinhala
|
||||
Language.SI: "si-LK",
|
||||
Language.SI_LK: "si-LK",
|
||||
# Slovak
|
||||
Language.SK: "sk-SK",
|
||||
Language.SK_SK: "sk-SK",
|
||||
# Slovenian
|
||||
Language.SL: "sl-SI",
|
||||
Language.SL_SI: "sl-SI",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain
|
||||
Language.ES_AR: "es-AR",
|
||||
Language.ES_BO: "es-BO",
|
||||
Language.ES_CL: "es-CL",
|
||||
Language.ES_CO: "es-CO",
|
||||
Language.ES_CR: "es-CR",
|
||||
Language.ES_DO: "es-DO",
|
||||
Language.ES_EC: "es-EC",
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_GT: "es-GT",
|
||||
Language.ES_HN: "es-HN",
|
||||
Language.ES_MX: "es-MX",
|
||||
Language.ES_NI: "es-NI",
|
||||
Language.ES_PA: "es-PA",
|
||||
Language.ES_PE: "es-PE",
|
||||
Language.ES_PR: "es-PR",
|
||||
Language.ES_PY: "es-PY",
|
||||
Language.ES_SV: "es-SV",
|
||||
Language.ES_US: "es-US",
|
||||
Language.ES_UY: "es-UY",
|
||||
Language.ES_VE: "es-VE",
|
||||
# Sundanese
|
||||
Language.SU: "su-ID",
|
||||
Language.SU_ID: "su-ID",
|
||||
# Swahili
|
||||
Language.SW: "sw-TZ", # Default to Tanzania
|
||||
Language.SW_KE: "sw-KE",
|
||||
Language.SW_TZ: "sw-TZ",
|
||||
# Swedish
|
||||
Language.SV: "sv-SE",
|
||||
Language.SV_SE: "sv-SE",
|
||||
# Tamil
|
||||
Language.TA: "ta-IN", # Default to India
|
||||
Language.TA_IN: "ta-IN",
|
||||
Language.TA_MY: "ta-MY",
|
||||
Language.TA_SG: "ta-SG",
|
||||
Language.TA_LK: "ta-LK",
|
||||
# Telugu
|
||||
Language.TE: "te-IN",
|
||||
Language.TE_IN: "te-IN",
|
||||
# Thai
|
||||
Language.TH: "th-TH",
|
||||
Language.TH_TH: "th-TH",
|
||||
# Turkish
|
||||
Language.TR: "tr-TR",
|
||||
Language.TR_TR: "tr-TR",
|
||||
# Ukrainian
|
||||
Language.UK: "uk-UA",
|
||||
Language.UK_UA: "uk-UA",
|
||||
# Urdu
|
||||
Language.UR: "ur-IN", # Default to India
|
||||
Language.UR_IN: "ur-IN",
|
||||
Language.UR_PK: "ur-PK",
|
||||
# Uzbek
|
||||
Language.UZ: "uz-UZ",
|
||||
Language.UZ_UZ: "uz-UZ",
|
||||
# Vietnamese
|
||||
Language.VI: "vi-VN",
|
||||
Language.VI_VN: "vi-VN",
|
||||
# Xhosa
|
||||
Language.XH: "xh-ZA",
|
||||
# Zulu
|
||||
Language.ZU: "zu-ZA",
|
||||
Language.ZU_ZA: "zu-ZA",
|
||||
}
|
||||
|
||||
return language_map.get(language)
|
||||
|
||||
|
||||
class GoogleSTTService(STTService):
|
||||
"""Google Cloud Speech-to-Text V2 service implementation.
|
||||
|
||||
Provides real-time speech recognition using Google Cloud's Speech-to-Text V2 API
|
||||
with streaming support. Handles audio transcription and optional voice activity detection.
|
||||
|
||||
Attributes:
|
||||
InputParams: Configuration parameters for the STT service.
|
||||
"""
|
||||
|
||||
# Google Cloud's STT service has a connection time limit of 5 minutes per stream.
|
||||
# They've shared an "endless streaming" example that guided this implementation:
|
||||
# https://cloud.google.com/speech-to-text/docs/transcribe-streaming-audio#endless-streaming
|
||||
|
||||
STREAMING_LIMIT = 240000 # 4 minutes in milliseconds
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Google Speech-to-Text.
|
||||
|
||||
Attributes:
|
||||
languages: Single language or list of recognition languages. First language is primary.
|
||||
model: Speech recognition model to use.
|
||||
use_separate_recognition_per_channel: Process each audio channel separately.
|
||||
enable_automatic_punctuation: Add punctuation to transcripts.
|
||||
enable_spoken_punctuation: Include spoken punctuation in transcript.
|
||||
enable_spoken_emojis: Include spoken emojis in transcript.
|
||||
profanity_filter: Filter profanity from transcript.
|
||||
enable_word_time_offsets: Include timing information for each word.
|
||||
enable_word_confidence: Include confidence scores for each word.
|
||||
enable_interim_results: Stream partial recognition results.
|
||||
enable_voice_activity_events: Detect voice activity in audio.
|
||||
"""
|
||||
|
||||
languages: Union[Language, List[Language]] = Field(default_factory=lambda: [Language.EN_US])
|
||||
model: Optional[str] = "latest_long"
|
||||
use_separate_recognition_per_channel: Optional[bool] = False
|
||||
enable_automatic_punctuation: Optional[bool] = True
|
||||
enable_spoken_punctuation: Optional[bool] = False
|
||||
enable_spoken_emojis: Optional[bool] = False
|
||||
profanity_filter: Optional[bool] = False
|
||||
enable_word_time_offsets: Optional[bool] = False
|
||||
enable_word_confidence: Optional[bool] = False
|
||||
enable_interim_results: Optional[bool] = True
|
||||
enable_voice_activity_events: Optional[bool] = False
|
||||
|
||||
@field_validator("languages", mode="before")
|
||||
@classmethod
|
||||
def validate_languages(cls, v) -> List[Language]:
|
||||
if isinstance(v, Language):
|
||||
return [v]
|
||||
return v
|
||||
|
||||
@property
|
||||
def language_list(self) -> List[Language]:
|
||||
"""Get languages as a guaranteed list."""
|
||||
assert isinstance(self.languages, list)
|
||||
return self.languages
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
location: str = "global",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Google STT service.
|
||||
|
||||
Args:
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to service account credentials JSON file.
|
||||
location: Google Cloud location (e.g., "global", "us-central1").
|
||||
sample_rate: Audio sample rate in Hertz.
|
||||
params: Configuration parameters for the service.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither credentials nor credentials_path is provided.
|
||||
ValueError: If project ID is not found in credentials.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._location = location
|
||||
self._stream = None
|
||||
self._config = None
|
||||
self._request_queue = asyncio.Queue()
|
||||
self._streaming_task = None
|
||||
|
||||
# Used for keep-alive logic
|
||||
self._stream_start_time = 0
|
||||
self._last_audio_input = []
|
||||
self._audio_input = []
|
||||
self._result_end_time = 0
|
||||
self._is_final_end_time = 0
|
||||
self._final_request_end_time = 0
|
||||
self._bridging_offset = 0
|
||||
self._last_transcript_was_final = False
|
||||
self._new_stream = True
|
||||
self._restart_counter = 0
|
||||
|
||||
# Configure client options based on location
|
||||
client_options = None
|
||||
if self._location != "global":
|
||||
client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
|
||||
|
||||
# Extract project ID and create client
|
||||
if credentials:
|
||||
json_account_info = json.loads(credentials)
|
||||
self._project_id = json_account_info.get("project_id")
|
||||
creds = service_account.Credentials.from_service_account_info(json_account_info)
|
||||
elif credentials_path:
|
||||
with open(credentials_path) as f:
|
||||
json_account_info = json.load(f)
|
||||
self._project_id = json_account_info.get("project_id")
|
||||
creds = service_account.Credentials.from_service_account_file(credentials_path)
|
||||
else:
|
||||
raise ValueError("Either credentials or credentials_path must be provided")
|
||||
|
||||
if not self._project_id:
|
||||
raise ValueError("Project ID not found in credentials")
|
||||
|
||||
self._client = speech_v2.SpeechAsyncClient(credentials=creds, client_options=client_options)
|
||||
|
||||
self._settings = {
|
||||
"language_codes": [
|
||||
self.language_to_service_language(lang) for lang in params.language_list
|
||||
],
|
||||
"model": params.model,
|
||||
"use_separate_recognition_per_channel": params.use_separate_recognition_per_channel,
|
||||
"enable_automatic_punctuation": params.enable_automatic_punctuation,
|
||||
"enable_spoken_punctuation": params.enable_spoken_punctuation,
|
||||
"enable_spoken_emojis": params.enable_spoken_emojis,
|
||||
"profanity_filter": params.profanity_filter,
|
||||
"enable_word_time_offsets": params.enable_word_time_offsets,
|
||||
"enable_word_confidence": params.enable_word_confidence,
|
||||
"enable_interim_results": params.enable_interim_results,
|
||||
"enable_voice_activity_events": params.enable_voice_activity_events,
|
||||
}
|
||||
|
||||
def language_to_service_language(self, language: Language | List[Language]) -> str | List[str]:
|
||||
"""Convert Language enum(s) to Google STT language code(s).
|
||||
|
||||
Args:
|
||||
language: Single Language enum or list of Language enums.
|
||||
|
||||
Returns:
|
||||
str | List[str]: Google STT language code(s).
|
||||
"""
|
||||
if isinstance(language, list):
|
||||
return [language_to_google_stt_language(lang) or "en-US" for lang in language]
|
||||
return language_to_google_stt_language(language) or "en-US"
|
||||
|
||||
async def _reconnect_if_needed(self):
|
||||
"""Reconnect the stream if it's currently active."""
|
||||
if self._streaming_task:
|
||||
logger.debug("Reconnecting stream due to configuration changes")
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Update the service's recognition language.
|
||||
|
||||
A convenience method for setting a single language.
|
||||
|
||||
Args:
|
||||
language: New language for recognition.
|
||||
"""
|
||||
logger.debug(f"Switching STT language to: {language}")
|
||||
await self.set_languages([language])
|
||||
|
||||
async def set_languages(self, languages: List[Language]):
|
||||
"""Update the service's recognition languages.
|
||||
|
||||
Args:
|
||||
languages: List of languages for recognition. First language is primary.
|
||||
"""
|
||||
logger.debug(f"Switching STT languages to: {languages}")
|
||||
self._settings["language_codes"] = [
|
||||
self.language_to_service_language(lang) for lang in languages
|
||||
]
|
||||
# Recreate stream with new languages
|
||||
await self._reconnect_if_needed()
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Update the service's recognition model."""
|
||||
logger.debug(f"Switching STT model to: {model}")
|
||||
await super().set_model(model)
|
||||
self._settings["model"] = model
|
||||
# Recreate stream with new model
|
||||
await self._reconnect_if_needed()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def update_options(
|
||||
self,
|
||||
*,
|
||||
languages: Optional[List[Language]] = None,
|
||||
model: Optional[str] = None,
|
||||
enable_automatic_punctuation: Optional[bool] = None,
|
||||
enable_spoken_punctuation: Optional[bool] = None,
|
||||
enable_spoken_emojis: Optional[bool] = None,
|
||||
profanity_filter: Optional[bool] = None,
|
||||
enable_word_time_offsets: Optional[bool] = None,
|
||||
enable_word_confidence: Optional[bool] = None,
|
||||
enable_interim_results: Optional[bool] = None,
|
||||
enable_voice_activity_events: Optional[bool] = None,
|
||||
location: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update service options dynamically.
|
||||
|
||||
Args:
|
||||
languages: New list of recongition languages.
|
||||
model: New recognition model.
|
||||
enable_automatic_punctuation: Enable/disable automatic punctuation.
|
||||
enable_spoken_punctuation: Enable/disable spoken punctuation.
|
||||
enable_spoken_emojis: Enable/disable spoken emojis.
|
||||
profanity_filter: Enable/disable profanity filter.
|
||||
enable_word_time_offsets: Enable/disable word timing info.
|
||||
enable_word_confidence: Enable/disable word confidence scores.
|
||||
enable_interim_results: Enable/disable interim results.
|
||||
enable_voice_activity_events: Enable/disable voice activity detection.
|
||||
location: New Google Cloud location.
|
||||
|
||||
Note:
|
||||
Changes that affect the streaming configuration will cause
|
||||
the stream to be reconnected.
|
||||
"""
|
||||
# Update settings with new values
|
||||
if languages is not None:
|
||||
logger.debug(f"Updating language to: {languages}")
|
||||
self._settings["language_codes"] = [
|
||||
self.language_to_service_language(lang) for lang in languages
|
||||
]
|
||||
|
||||
if model is not None:
|
||||
logger.debug(f"Updating model to: {model}")
|
||||
self._settings["model"] = model
|
||||
|
||||
if enable_automatic_punctuation is not None:
|
||||
logger.debug(f"Updating automatic punctuation to: {enable_automatic_punctuation}")
|
||||
self._settings["enable_automatic_punctuation"] = enable_automatic_punctuation
|
||||
|
||||
if enable_spoken_punctuation is not None:
|
||||
logger.debug(f"Updating spoken punctuation to: {enable_spoken_punctuation}")
|
||||
self._settings["enable_spoken_punctuation"] = enable_spoken_punctuation
|
||||
|
||||
if enable_spoken_emojis is not None:
|
||||
logger.debug(f"Updating spoken emojis to: {enable_spoken_emojis}")
|
||||
self._settings["enable_spoken_emojis"] = enable_spoken_emojis
|
||||
|
||||
if profanity_filter is not None:
|
||||
logger.debug(f"Updating profanity filter to: {profanity_filter}")
|
||||
self._settings["profanity_filter"] = profanity_filter
|
||||
|
||||
if enable_word_time_offsets is not None:
|
||||
logger.debug(f"Updating word time offsets to: {enable_word_time_offsets}")
|
||||
self._settings["enable_word_time_offsets"] = enable_word_time_offsets
|
||||
|
||||
if enable_word_confidence is not None:
|
||||
logger.debug(f"Updating word confidence to: {enable_word_confidence}")
|
||||
self._settings["enable_word_confidence"] = enable_word_confidence
|
||||
|
||||
if enable_interim_results is not None:
|
||||
logger.debug(f"Updating interim results to: {enable_interim_results}")
|
||||
self._settings["enable_interim_results"] = enable_interim_results
|
||||
|
||||
if enable_voice_activity_events is not None:
|
||||
logger.debug(f"Updating voice activity events to: {enable_voice_activity_events}")
|
||||
self._settings["enable_voice_activity_events"] = enable_voice_activity_events
|
||||
|
||||
if location is not None:
|
||||
logger.debug(f"Updating location to: {location}")
|
||||
self._location = location
|
||||
|
||||
# Reconnect the stream for updates
|
||||
await self._reconnect_if_needed()
|
||||
|
||||
async def _connect(self):
|
||||
"""Initialize streaming recognition config and stream."""
|
||||
logger.debug("Connecting to Google Speech-to-Text")
|
||||
|
||||
# Set stream start time
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
self._new_stream = True
|
||||
|
||||
self._config = cloud_speech.StreamingRecognitionConfig(
|
||||
config=cloud_speech.RecognitionConfig(
|
||||
explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
|
||||
encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
language_codes=self._settings["language_codes"],
|
||||
model=self._settings["model"],
|
||||
features=cloud_speech.RecognitionFeatures(
|
||||
enable_automatic_punctuation=self._settings["enable_automatic_punctuation"],
|
||||
enable_spoken_punctuation=self._settings["enable_spoken_punctuation"],
|
||||
enable_spoken_emojis=self._settings["enable_spoken_emojis"],
|
||||
profanity_filter=self._settings["profanity_filter"],
|
||||
enable_word_time_offsets=self._settings["enable_word_time_offsets"],
|
||||
enable_word_confidence=self._settings["enable_word_confidence"],
|
||||
),
|
||||
),
|
||||
streaming_features=cloud_speech.StreamingRecognitionFeatures(
|
||||
enable_voice_activity_events=self._settings["enable_voice_activity_events"],
|
||||
interim_results=self._settings["enable_interim_results"],
|
||||
),
|
||||
)
|
||||
|
||||
self._streaming_task = self.create_task(self._stream_audio())
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Clean up streaming recognition resources."""
|
||||
if self._streaming_task:
|
||||
logger.debug("Disconnecting from Google Speech-to-Text")
|
||||
# Send sentinel value to stop request generator
|
||||
await self._request_queue.put(None)
|
||||
await self.cancel_task(self._streaming_task)
|
||||
self._streaming_task = None
|
||||
# Clear any remaining items in the queue
|
||||
while not self._request_queue.empty():
|
||||
try:
|
||||
self._request_queue.get_nowait()
|
||||
self._request_queue.task_done()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
async def _request_generator(self):
|
||||
"""Generates requests for the streaming recognize method."""
|
||||
recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_"
|
||||
logger.trace(f"Using recognizer path: {recognizer_path}")
|
||||
|
||||
try:
|
||||
# Send initial config
|
||||
yield cloud_speech.StreamingRecognizeRequest(
|
||||
recognizer=recognizer_path,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
audio_data = await self._request_queue.get()
|
||||
if audio_data is None: # Sentinel value to stop
|
||||
break
|
||||
|
||||
# Check streaming limit
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Streaming limit reached, initiating graceful reconnection")
|
||||
# Instead of immediate reconnection, we'll break and let the stream close naturally
|
||||
self._last_audio_input = self._audio_input
|
||||
self._audio_input = []
|
||||
self._restart_counter += 1
|
||||
# Put the current audio chunk back in the queue
|
||||
await self._request_queue.put(audio_data)
|
||||
break
|
||||
|
||||
self._audio_input.append(audio_data)
|
||||
yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
self._request_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in request generator: {e}")
|
||||
raise
|
||||
|
||||
async def _stream_audio(self):
|
||||
"""Handle bi-directional streaming with Google STT."""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Start bi-directional streaming
|
||||
streaming_recognize = await self._client.streaming_recognize(
|
||||
requests=self._request_generator()
|
||||
)
|
||||
|
||||
# Process responses
|
||||
await self._process_responses(streaming_recognize)
|
||||
|
||||
# If we're here, check if we need to reconnect
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Reconnecting stream after timeout")
|
||||
# Reset stream start time
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
continue
|
||||
else:
|
||||
# Normal stream end
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self} Reconnecting: {e}")
|
||||
|
||||
await asyncio.sleep(1) # Brief delay before reconnecting
|
||||
self._stream_start_time = int(time.time() * 1000)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming task: {e}")
|
||||
await self.push_frame(ErrorFrame(str(e)))
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process an audio chunk for STT transcription."""
|
||||
if self._streaming_task:
|
||||
# Queue the audio data
|
||||
await self._request_queue.put(audio)
|
||||
yield None
|
||||
|
||||
async def _process_responses(self, streaming_recognize):
|
||||
"""Process streaming recognition responses."""
|
||||
try:
|
||||
async for response in streaming_recognize:
|
||||
# Check streaming limit
|
||||
if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
|
||||
logger.debug("Stream timeout reached in response processing")
|
||||
break
|
||||
|
||||
if not response.results:
|
||||
continue
|
||||
|
||||
for result in response.results:
|
||||
if not result.alternatives:
|
||||
continue
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if not transcript:
|
||||
continue
|
||||
|
||||
primary_language = self._settings["language_codes"][0]
|
||||
|
||||
if result.is_final:
|
||||
self._last_transcript_was_final = True
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language)
|
||||
)
|
||||
else:
|
||||
self._last_transcript_was_final = False
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript, "", time_now_iso8601(), primary_language
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Google STT responses: {e}")
|
||||
|
||||
# Re-raise the exception to let it propagate (e.g. in the case of a timeout, propagate to _stream_audio to reconnect)
|
||||
raise
|
||||
364
src/pipecat/services/google/tts.py
Normal file
364
src/pipecat/services/google/tts.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from typing import AsyncGenerator, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
from google.cloud import texttospeech_v1
|
||||
from google.oauth2 import service_account
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_APPLICATION_CREDENTIALS` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_google_tts_language(language: Language) -> Optional[str]:
|
||||
language_map = {
|
||||
# Afrikaans
|
||||
Language.AF: "af-ZA",
|
||||
Language.AF_ZA: "af-ZA",
|
||||
# Arabic
|
||||
Language.AR: "ar-XA",
|
||||
# Bengali
|
||||
Language.BN: "bn-IN",
|
||||
Language.BN_IN: "bn-IN",
|
||||
# Bulgarian
|
||||
Language.BG: "bg-BG",
|
||||
Language.BG_BG: "bg-BG",
|
||||
# Catalan
|
||||
Language.CA: "ca-ES",
|
||||
Language.CA_ES: "ca-ES",
|
||||
# Chinese (Mandarin and Cantonese)
|
||||
Language.ZH: "cmn-CN",
|
||||
Language.ZH_CN: "cmn-CN",
|
||||
Language.ZH_TW: "cmn-TW",
|
||||
Language.ZH_HK: "yue-HK",
|
||||
# Czech
|
||||
Language.CS: "cs-CZ",
|
||||
Language.CS_CZ: "cs-CZ",
|
||||
# Danish
|
||||
Language.DA: "da-DK",
|
||||
Language.DA_DK: "da-DK",
|
||||
# Dutch
|
||||
Language.NL: "nl-NL",
|
||||
Language.NL_BE: "nl-BE",
|
||||
Language.NL_NL: "nl-NL",
|
||||
# English
|
||||
Language.EN: "en-US",
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_AU: "en-AU",
|
||||
Language.EN_GB: "en-GB",
|
||||
Language.EN_IN: "en-IN",
|
||||
# Estonian
|
||||
Language.ET: "et-EE",
|
||||
Language.ET_EE: "et-EE",
|
||||
# Filipino
|
||||
Language.FIL: "fil-PH",
|
||||
Language.FIL_PH: "fil-PH",
|
||||
# Finnish
|
||||
Language.FI: "fi-FI",
|
||||
Language.FI_FI: "fi-FI",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_CA: "fr-CA",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# Galician
|
||||
Language.GL: "gl-ES",
|
||||
Language.GL_ES: "gl-ES",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Greek
|
||||
Language.EL: "el-GR",
|
||||
Language.EL_GR: "el-GR",
|
||||
# Gujarati
|
||||
Language.GU: "gu-IN",
|
||||
Language.GU_IN: "gu-IN",
|
||||
# Hebrew
|
||||
Language.HE: "he-IL",
|
||||
Language.HE_IL: "he-IL",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Hungarian
|
||||
Language.HU: "hu-HU",
|
||||
Language.HU_HU: "hu-HU",
|
||||
# Icelandic
|
||||
Language.IS: "is-IS",
|
||||
Language.IS_IS: "is-IS",
|
||||
# Indonesian
|
||||
Language.ID: "id-ID",
|
||||
Language.ID_ID: "id-ID",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Kannada
|
||||
Language.KN: "kn-IN",
|
||||
Language.KN_IN: "kn-IN",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Latvian
|
||||
Language.LV: "lv-LV",
|
||||
Language.LV_LV: "lv-LV",
|
||||
# Lithuanian
|
||||
Language.LT: "lt-LT",
|
||||
Language.LT_LT: "lt-LT",
|
||||
# Malay
|
||||
Language.MS: "ms-MY",
|
||||
Language.MS_MY: "ms-MY",
|
||||
# Malayalam
|
||||
Language.ML: "ml-IN",
|
||||
Language.ML_IN: "ml-IN",
|
||||
# Marathi
|
||||
Language.MR: "mr-IN",
|
||||
Language.MR_IN: "mr-IN",
|
||||
# Norwegian
|
||||
Language.NO: "nb-NO",
|
||||
Language.NB: "nb-NO",
|
||||
Language.NB_NO: "nb-NO",
|
||||
# Polish
|
||||
Language.PL: "pl-PL",
|
||||
Language.PL_PL: "pl-PL",
|
||||
# Portuguese
|
||||
Language.PT: "pt-PT",
|
||||
Language.PT_BR: "pt-BR",
|
||||
Language.PT_PT: "pt-PT",
|
||||
# Punjabi
|
||||
Language.PA: "pa-IN",
|
||||
Language.PA_IN: "pa-IN",
|
||||
# Romanian
|
||||
Language.RO: "ro-RO",
|
||||
Language.RO_RO: "ro-RO",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Serbian
|
||||
Language.SR: "sr-RS",
|
||||
Language.SR_RS: "sr-RS",
|
||||
# Slovak
|
||||
Language.SK: "sk-SK",
|
||||
Language.SK_SK: "sk-SK",
|
||||
# Spanish
|
||||
Language.ES: "es-ES",
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_US: "es-US",
|
||||
# Swedish
|
||||
Language.SV: "sv-SE",
|
||||
Language.SV_SE: "sv-SE",
|
||||
# Tamil
|
||||
Language.TA: "ta-IN",
|
||||
Language.TA_IN: "ta-IN",
|
||||
# Telugu
|
||||
Language.TE: "te-IN",
|
||||
Language.TE_IN: "te-IN",
|
||||
# Thai
|
||||
Language.TH: "th-TH",
|
||||
Language.TH_TH: "th-TH",
|
||||
# Turkish
|
||||
Language.TR: "tr-TR",
|
||||
Language.TR_TR: "tr-TR",
|
||||
# Ukrainian
|
||||
Language.UK: "uk-UA",
|
||||
Language.UK_UA: "uk-UA",
|
||||
# Vietnamese
|
||||
Language.VI: "vi-VN",
|
||||
Language.VI_VN: "vi-VN",
|
||||
}
|
||||
|
||||
return language_map.get(language)
|
||||
|
||||
|
||||
class GoogleTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
gender: Optional[Literal["male", "female", "neutral"]] = None
|
||||
google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
voice_id: str = "en-US-Neural2-A",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._settings = {
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"volume": params.volume,
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en-US",
|
||||
"gender": params.gender,
|
||||
"google_style": params.google_style,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
|
||||
credentials, credentials_path
|
||||
)
|
||||
|
||||
def _create_client(
|
||||
self, credentials: Optional[str], credentials_path: Optional[str]
|
||||
) -> texttospeech_v1.TextToSpeechAsyncClient:
|
||||
creds: Optional[service_account.Credentials] = None
|
||||
|
||||
# Create a Google Cloud service account for the Cloud Text-to-Speech API
|
||||
# Using either the provided credentials JSON string or the path to a service account JSON
|
||||
# file, create a Google Cloud service account and use it to authenticate with the API.
|
||||
if credentials:
|
||||
# Use provided credentials JSON string
|
||||
json_account_info = json.loads(credentials)
|
||||
creds = service_account.Credentials.from_service_account_info(json_account_info)
|
||||
elif credentials_path:
|
||||
# Use service account JSON file if provided
|
||||
creds = service_account.Credentials.from_service_account_file(credentials_path)
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
return language_to_google_tts_language(language)
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
ssml = "<speak>"
|
||||
|
||||
# Voice tag
|
||||
voice_attrs = [f"name='{self._voice_id}'"]
|
||||
|
||||
language = self._settings["language"]
|
||||
voice_attrs.append(f"language='{language}'")
|
||||
|
||||
if self._settings["gender"]:
|
||||
voice_attrs.append(f"gender='{self._settings['gender']}'")
|
||||
ssml += f"<voice {' '.join(voice_attrs)}>"
|
||||
|
||||
# Prosody tag
|
||||
prosody_attrs = []
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
# Emphasis tag
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
# Google style tag
|
||||
if self._settings["google_style"]:
|
||||
ssml += f"<google:style name='{self._settings['google_style']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
# Close tags
|
||||
if self._settings["google_style"]:
|
||||
ssml += "</google:style>"
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
ssml += "</voice></speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Check if the voice is a Chirp voice (including Chirp 3) or Journey voice
|
||||
is_chirp_voice = "chirp" in self._voice_id.lower()
|
||||
is_journey_voice = "journey" in self._voice_id.lower()
|
||||
|
||||
# Create synthesis input based on voice_id
|
||||
if is_chirp_voice or is_journey_voice:
|
||||
# Chirp and Journey voices don't support SSML, use plain text
|
||||
synthesis_input = texttospeech_v1.SynthesisInput(text=text)
|
||||
else:
|
||||
ssml = self._construct_ssml(text)
|
||||
synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml)
|
||||
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
audio_config = texttospeech_v1.AudioConfig(
|
||||
audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
)
|
||||
|
||||
request = texttospeech_v1.SynthesizeSpeechRequest(
|
||||
input=synthesis_input, voice=voice, audio_config=audio_config
|
||||
)
|
||||
|
||||
response = await self._client.synthesize_speech(request=request)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Skip the first 44 bytes to remove the WAV header
|
||||
audio_content = response.audio_content[44:]
|
||||
|
||||
# Read and yield audio data in chunks
|
||||
chunk_size = 8192
|
||||
for i in range(0, len(audio_content), chunk_size):
|
||||
chunk = audio_content[i : i + chunk_size]
|
||||
if not chunk:
|
||||
break
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
await asyncio.sleep(0) # Allow other tasks to run
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
13
src/pipecat/services/grok/__init__.py
Normal file
13
src/pipecat/services/grok/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "grok", "grok.llm")
|
||||
@@ -4,21 +4,14 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Mapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import FunctionCallResultProperties
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai import (
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMService,
|
||||
OpenAIUserContextAggregator,
|
||||
@@ -27,13 +20,13 @@ from pipecat.services.openai import (
|
||||
|
||||
@dataclass
|
||||
class GrokContextAggregatorPair:
|
||||
_user: "OpenAIUserContextAggregator"
|
||||
_assistant: "OpenAIAssistantContextAggregator"
|
||||
_user: OpenAIUserContextAggregator
|
||||
_assistant: OpenAIAssistantContextAggregator
|
||||
|
||||
def user(self) -> "OpenAIUserContextAggregator":
|
||||
def user(self) -> OpenAIUserContextAggregator:
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAssistantContextAggregator":
|
||||
def assistant(self) -> OpenAIAssistantContextAggregator:
|
||||
return self._assistant
|
||||
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
from groq import AsyncGroq
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Groq, you need to `pip install pipecat-ai[groq]`. Also, set a `GROQ_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GroqLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Groq's API using the OpenAI-compatible interface.
|
||||
|
||||
This service extends OpenAILLMService to connect to Groq's API endpoint while
|
||||
maintaining full compatibility with OpenAI's interface and functionality.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for accessing Groq's API
|
||||
base_url (str, optional): The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1"
|
||||
model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b-versatile"
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.groq.com/openai/v1",
|
||||
model: str = "llama-3.3-70b-versatile",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
"""Create OpenAI-compatible client for Groq API endpoint."""
|
||||
logger.debug(f"Creating Groq client with api {base_url}")
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
|
||||
|
||||
class GroqSTTService(BaseWhisperSTTService):
|
||||
"""Groq Whisper speech-to-text service.
|
||||
|
||||
Uses Groq's Whisper API to convert audio to text. Requires a Groq API key
|
||||
set via the api_key parameter or GROQ_API_KEY environment variable.
|
||||
|
||||
Args:
|
||||
model: Whisper model to use. Defaults to "whisper-large-v3-turbo".
|
||||
api_key: Groq API key. Defaults to None.
|
||||
base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "whisper-large-v3-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = "https://api.groq.com/openai/v1",
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
assert self._language is not None # Assigned in the BaseWhisperSTTService class
|
||||
|
||||
# Build kwargs dict with only set parameters
|
||||
kwargs = {
|
||||
"file": ("audio.wav", audio, "audio/wav"),
|
||||
"model": self.model_name,
|
||||
"response_format": "json",
|
||||
"language": self._language,
|
||||
}
|
||||
|
||||
if self._prompt is not None:
|
||||
kwargs["prompt"] = self._prompt
|
||||
|
||||
if self._temperature is not None:
|
||||
kwargs["temperature"] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
|
||||
|
||||
class GroqTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[float] = 1.0
|
||||
seed: Optional[int] = None
|
||||
|
||||
GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
output_format: str = "wav",
|
||||
params: InputParams = InputParams(),
|
||||
model_name: str = "playai-tts",
|
||||
voice_id: str = "Celeste-PlayAI",
|
||||
sample_rate: Optional[int] = GROQ_SAMPLE_RATE,
|
||||
**kwargs,
|
||||
):
|
||||
if sample_rate != self.GROQ_SAMPLE_RATE:
|
||||
logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ")
|
||||
super().__init__(
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._model_name = model_name
|
||||
self._output_format = output_format
|
||||
self._voice_id = voice_id
|
||||
self._params = params
|
||||
|
||||
self._client = AsyncGroq(api_key=self._api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
measuring_ttfb = True
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
response = await self._client.audio.speech.create(
|
||||
model=self._model_name,
|
||||
voice=self._voice_id,
|
||||
response_format=self._output_format,
|
||||
input=text,
|
||||
)
|
||||
|
||||
async for data in response.iter_bytes():
|
||||
if measuring_ttfb:
|
||||
await self.stop_ttfb_metrics()
|
||||
measuring_ttfb = False
|
||||
# remove wav header if present
|
||||
if data.startswith(b"RIFF"):
|
||||
data = data[44:]
|
||||
if len(data) == 0:
|
||||
continue
|
||||
yield TTSAudioRawFrame(data, self.sample_rate, 1)
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
15
src/pipecat/services/groq/__init__.py
Normal file
15
src/pipecat/services/groq/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "groq", "groq.[llm,stt,tts]")
|
||||
38
src/pipecat/services/groq/llm.py
Normal file
38
src/pipecat/services/groq/llm.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class GroqLLMService(OpenAILLMService):
|
||||
"""A service for interacting with Groq's API using the OpenAI-compatible interface.
|
||||
|
||||
This service extends OpenAILLMService to connect to Groq's API endpoint while
|
||||
maintaining full compatibility with OpenAI's interface and functionality.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for accessing Groq's API
|
||||
base_url (str, optional): The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1"
|
||||
model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b-versatile"
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.groq.com/openai/v1",
|
||||
model: str = "llama-3.3-70b-versatile",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
"""Create OpenAI-compatible client for Groq API endpoint."""
|
||||
logger.debug(f"Creating Groq client with api {base_url}")
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
67
src/pipecat/services/groq/stt.py
Normal file
67
src/pipecat/services/groq/stt.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
class GroqSTTService(BaseWhisperSTTService):
|
||||
"""Groq Whisper speech-to-text service.
|
||||
|
||||
Uses Groq's Whisper API to convert audio to text. Requires a Groq API key
|
||||
set via the api_key parameter or GROQ_API_KEY environment variable.
|
||||
|
||||
Args:
|
||||
model: Whisper model to use. Defaults to "whisper-large-v3-turbo".
|
||||
api_key: Groq API key. Defaults to None.
|
||||
base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "whisper-large-v3-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = "https://api.groq.com/openai/v1",
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
assert self._language is not None # Assigned in the BaseWhisperSTTService class
|
||||
|
||||
# Build kwargs dict with only set parameters
|
||||
kwargs = {
|
||||
"file": ("audio.wav", audio, "audio/wav"),
|
||||
"model": self.model_name,
|
||||
"response_format": "json",
|
||||
"language": self._language,
|
||||
}
|
||||
|
||||
if self._prompt is not None:
|
||||
kwargs["prompt"] = self._prompt
|
||||
|
||||
if self._temperature is not None:
|
||||
kwargs["temperature"] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
86
src/pipecat/services/groq/tts.py
Normal file
86
src/pipecat/services/groq/tts.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
from groq import AsyncGroq
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Groq, you need to `pip install pipecat-ai[groq]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class GroqTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[float] = 1.0
|
||||
seed: Optional[int] = None
|
||||
|
||||
GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
output_format: str = "wav",
|
||||
params: InputParams = InputParams(),
|
||||
model_name: str = "playai-tts",
|
||||
voice_id: str = "Celeste-PlayAI",
|
||||
sample_rate: Optional[int] = GROQ_SAMPLE_RATE,
|
||||
**kwargs,
|
||||
):
|
||||
if sample_rate != self.GROQ_SAMPLE_RATE:
|
||||
logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ")
|
||||
super().__init__(
|
||||
pause_frame_processing=True,
|
||||
sample_rate=sample_rate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._model_name = model_name
|
||||
self._output_format = output_format
|
||||
self._voice_id = voice_id
|
||||
self._params = params
|
||||
|
||||
self._client = AsyncGroq(api_key=self._api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
measuring_ttfb = True
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
response = await self._client.audio.speech.create(
|
||||
model=self._model_name,
|
||||
voice=self._voice_id,
|
||||
response_format=self._output_format,
|
||||
input=text,
|
||||
)
|
||||
|
||||
async for data in response.iter_bytes():
|
||||
if measuring_ttfb:
|
||||
await self.stop_ttfb_metrics()
|
||||
measuring_ttfb = False
|
||||
# remove wav header if present
|
||||
if data.startswith(b"RIFF"):
|
||||
data = data[44:]
|
||||
if len(data) == 0:
|
||||
continue
|
||||
yield TTSAudioRawFrame(data, self.sample_rate, 1)
|
||||
|
||||
yield TTSStoppedFrame()
|
||||
13
src/pipecat/services/lmnt/__init__.py
Normal file
13
src/pipecat/services/lmnt/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "lmnt", "lmnt.tts")
|
||||
@@ -29,9 +29,7 @@ try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`. Also, set `LMNT_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
13
src/pipecat/services/mem0/__init__.py
Normal file
13
src/pipecat/services/mem0/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .memory import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "mem0", "mem0.memory")
|
||||
13
src/pipecat/services/moondream/__init__.py
Normal file
13
src/pipecat/services/moondream/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .vision import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "moondream", "moondream.vision")
|
||||
13
src/pipecat/services/neuphonic/__init__.py
Normal file
13
src/pipecat/services/neuphonic/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "neuphonic", "neuphonic.tts")
|
||||
@@ -30,15 +30,12 @@ from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import InterruptibleTTSService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for Neuphonic configuration needed
|
||||
try:
|
||||
import websockets
|
||||
from pyneuphonic import Neuphonic, TTSConfig
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Neuphonic, you need to `pip install pipecat-ai[neuphonic]`. Also, set `NEUPHONIC_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use Neuphonic, you need to `pip install pipecat-ai[neuphonic]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
13
src/pipecat/services/nim/__init__.py
Normal file
13
src/pipecat/services/nim/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "nim", "nim.llm")
|
||||
@@ -4,10 +4,9 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class NimLLMService(OpenAILLMService):
|
||||
13
src/pipecat/services/ollama/__init__.py
Normal file
13
src/pipecat/services/ollama/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "ollama", "ollama.llm")
|
||||
@@ -4,9 +4,9 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from pipecat.services.openai import BaseOpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class OLLamaLLMService(BaseOpenAILLMService):
|
||||
class OLLamaLLMService(OpenAILLMService):
|
||||
def __init__(self, *, model: str = "llama2", base_url: str = "http://localhost:11434/v1"):
|
||||
super().__init__(model=model, base_url=base_url, api_key="ollama")
|
||||
@@ -1,644 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
import aiohttp
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from openai import (
|
||||
NOT_GIVEN,
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
BadRequestError,
|
||||
DefaultAsyncHttpxClient,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
URLImageRawFrame,
|
||||
UserImageRawFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import (
|
||||
ImageGenService,
|
||||
LLMService,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
|
||||
|
||||
VALID_VOICES: Dict[str, ValidVoice] = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseOpenAILLMService(LLMService):
|
||||
"""This is the base for all services that use the AsyncOpenAI client.
|
||||
|
||||
This service consumes OpenAILLMContextFrame frames, which contain a reference
|
||||
to an OpenAILLMContext frame. The OpenAILLMContext object defines the context
|
||||
sent to the LLM for a completion. This includes user, assistant and system messages
|
||||
as well as tool choices and the tool, which is used if requesting function
|
||||
calls from the LLM.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(
|
||||
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
|
||||
)
|
||||
presence_penalty: Optional[float] = Field(
|
||||
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
|
||||
)
|
||||
seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
|
||||
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
|
||||
# Note: top_k is currently not supported by the OpenAI client library,
|
||||
# so top_k is ignored right now.
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
|
||||
max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
organization=None,
|
||||
project=None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._settings = {
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"max_tokens": params.max_tokens,
|
||||
"max_completion_tokens": params.max_completion_tokens,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._client = self.create_client(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
organization=organization,
|
||||
project=project,
|
||||
default_headers=default_headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
organization=None,
|
||||
project=None,
|
||||
default_headers=None,
|
||||
**kwargs,
|
||||
):
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
organization=organization,
|
||||
project=project,
|
||||
http_client=DefaultAsyncHttpxClient(
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
|
||||
)
|
||||
),
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def get_chat_completions(
|
||||
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"stream": True,
|
||||
"messages": messages,
|
||||
"tools": context.tools,
|
||||
"tool_choice": context.tool_choice,
|
||||
"stream_options": {"include_usage": True},
|
||||
"frequency_penalty": self._settings["frequency_penalty"],
|
||||
"presence_penalty": self._settings["presence_penalty"],
|
||||
"seed": self._settings["seed"],
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"max_completion_tokens": self._settings["max_completion_tokens"],
|
||||
}
|
||||
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
chunks = await self._client.chat.completions.create(**params)
|
||||
return chunks
|
||||
|
||||
async def _stream_chat_completions(
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
|
||||
# base64 encode any images
|
||||
for message in messages:
|
||||
if message.get("mime_type") == "image/jpeg":
|
||||
encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
|
||||
text = message["content"]
|
||||
message["content"] = [
|
||||
{"type": "text", "text": text},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
||||
},
|
||||
]
|
||||
del message["data"]
|
||||
del message["mime_type"]
|
||||
|
||||
chunks = await self.get_chat_completions(context, messages)
|
||||
|
||||
return chunks
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
|
||||
context
|
||||
)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
# We're streaming the LLM response to enable the fastest response times.
|
||||
# For text, we just yield each chunk as we receive it and count on consumers
|
||||
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
|
||||
#
|
||||
# If the LLM is a function call, we'll do some coalescing here.
|
||||
# If the response contains a function name, we'll yield a frame to tell consumers
|
||||
# that they can start preparing to call the function with that name.
|
||||
# We accumulate all the arguments for the rest of the streamed response, then when
|
||||
# the response is done, we package up all the arguments and the function name and
|
||||
# yield a frame containing the function name and the arguments.
|
||||
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||
):
|
||||
if self.has_function(function_name):
|
||||
run_llm = False
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: OpenAILLMContext = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self._process_context(context)
|
||||
except httpx.TimeoutException:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIContextAggregatorPair:
|
||||
_user: "OpenAIUserContextAggregator"
|
||||
_assistant: "OpenAIAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIUserContextAggregator":
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAssistantContextAggregator":
|
||||
return self._assistant
|
||||
|
||||
|
||||
class OpenAILLMService(BaseOpenAILLMService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model=model, params=params, **kwargs)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
"""Create an instance of OpenAIContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
assistant aggregators can be provided.
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIContextAggregatorPair.
|
||||
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
user = OpenAIUserContextAggregator(context, **user_kwargs)
|
||||
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
class OpenAIImageGenService(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
model: str = "dall-e-3",
|
||||
):
|
||||
super().__init__()
|
||||
self.set_model_name(model)
|
||||
self._image_size = image_size
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
|
||||
image = await self._client.images.generate(
|
||||
prompt=prompt, model=self.model_name, n=1, size=self._image_size
|
||||
)
|
||||
|
||||
image_url = image.data[0].url
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} No image provided in response: {image}")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
# Load the image from the url
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
|
||||
yield frame
|
||||
|
||||
|
||||
class OpenAISTTService(BaseWhisperSTTService):
|
||||
"""OpenAI Speech-to-Text service that generates text from audio.
|
||||
|
||||
Uses OpenAI's transcription API to convert audio to text. Requires an OpenAI API key
|
||||
set via the api_key parameter or OPENAI_API_KEY environment variable.
|
||||
|
||||
Args:
|
||||
model: Model to use — either gpt-4o or Whisper. Defaults to "gpt-4o-transcribe".
|
||||
api_key: OpenAI API key. Defaults to None.
|
||||
base_url: API base URL. Defaults to None.
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "gpt-4o-transcribe",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
assert self._language is not None # Assigned in the BaseWhisperSTTService class
|
||||
|
||||
# Build kwargs dict with only set parameters
|
||||
kwargs = {
|
||||
"file": ("audio.wav", audio, "audio/wav"),
|
||||
"model": self.model_name,
|
||||
"language": self._language,
|
||||
}
|
||||
|
||||
if self._prompt is not None:
|
||||
kwargs["prompt"] = self._prompt
|
||||
|
||||
if self._temperature is not None:
|
||||
kwargs["temperature"] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
|
||||
|
||||
class OpenAITTSService(TTSService):
|
||||
"""OpenAI Text-to-Speech service that generates audio from text.
|
||||
|
||||
This service uses the OpenAI TTS API to generate PCM-encoded audio at 24kHz.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key. Defaults to None.
|
||||
voice: Voice ID to use. Defaults to "alloy".
|
||||
model: TTS model to use. Defaults to "gpt-4o-mini-tts".
|
||||
sample_rate: Output audio sample rate in Hz. Defaults to None.
|
||||
**kwargs: Additional keyword arguments passed to TTSService.
|
||||
|
||||
The service returns PCM-encoded audio at the specified sample rate.
|
||||
|
||||
"""
|
||||
|
||||
OPENAI_SAMPLE_RATE = 24000 # OpenAI TTS always outputs at 24kHz
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
voice: str = "alloy",
|
||||
model: str = "gpt-4o-mini-tts",
|
||||
sample_rate: Optional[int] = None,
|
||||
instructions: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
f"OpenAI TTS only supports {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
|
||||
f"Current rate of {self.sample_rate}Hz may cause issues."
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice)
|
||||
self._instructions = instructions
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.info(f"Switching TTS model to: [{model}]")
|
||||
self.set_model_name(model)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if self.sample_rate != self.OPENAI_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
f"OpenAI TTS requires {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
|
||||
f"Current rate of {self.sample_rate}Hz may cause issues."
|
||||
)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Setup extra body parameters
|
||||
extra_body = {}
|
||||
if self._instructions:
|
||||
extra_body["instructions"] = self._instructions
|
||||
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
input=text or " ", # Text must contain at least one character
|
||||
model=self.model_name,
|
||||
voice=VALID_VOICES[self._voice_id],
|
||||
response_format="pcm",
|
||||
extra_body=extra_body,
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
error = await r.text()
|
||||
logger.error(
|
||||
f"{self} error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
CHUNK_SIZE = 1024
|
||||
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in r.iter_bytes(CHUNK_SIZE):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result)
|
||||
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if (
|
||||
message["role"] == "tool"
|
||||
and message["tool_call_id"]
|
||||
and message["tool_call_id"] == tool_call_id
|
||||
):
|
||||
message["content"] = result
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
16
src/pipecat/services/openai/__init__.py
Normal file
16
src/pipecat/services/openai/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .image import *
|
||||
from .llm import *
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openai", "openai.[image,llm,stt,tts]")
|
||||
296
src/pipecat/services/openai/base_llm.py
Normal file
296
src/pipecat/services/openai/base_llm.py
Normal file
@@ -0,0 +1,296 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from openai import (
|
||||
NOT_GIVEN,
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
DefaultAsyncHttpxClient,
|
||||
)
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
|
||||
|
||||
class OpenAIUnhandledFunctionException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseOpenAILLMService(LLMService):
|
||||
"""This is the base for all services that use the AsyncOpenAI client.
|
||||
|
||||
This service consumes OpenAILLMContextFrame frames, which contain a reference
|
||||
to an OpenAILLMContext frame. The OpenAILLMContext object defines the context
|
||||
sent to the LLM for a completion. This includes user, assistant and system messages
|
||||
as well as tool choices and the tool, which is used if requesting function
|
||||
calls from the LLM.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(
|
||||
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
|
||||
)
|
||||
presence_penalty: Optional[float] = Field(
|
||||
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
|
||||
)
|
||||
seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
|
||||
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
|
||||
# Note: top_k is currently not supported by the OpenAI client library,
|
||||
# so top_k is ignored right now.
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
|
||||
max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
|
||||
max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
organization=None,
|
||||
project=None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._settings = {
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"max_tokens": params.max_tokens,
|
||||
"max_completion_tokens": params.max_completion_tokens,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._client = self.create_client(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
organization=organization,
|
||||
project=project,
|
||||
default_headers=default_headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
organization=None,
|
||||
project=None,
|
||||
default_headers=None,
|
||||
**kwargs,
|
||||
):
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
organization=organization,
|
||||
project=project,
|
||||
http_client=DefaultAsyncHttpxClient(
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
|
||||
)
|
||||
),
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def get_chat_completions(
|
||||
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"stream": True,
|
||||
"messages": messages,
|
||||
"tools": context.tools,
|
||||
"tool_choice": context.tool_choice,
|
||||
"stream_options": {"include_usage": True},
|
||||
"frequency_penalty": self._settings["frequency_penalty"],
|
||||
"presence_penalty": self._settings["presence_penalty"],
|
||||
"seed": self._settings["seed"],
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"max_completion_tokens": self._settings["max_completion_tokens"],
|
||||
}
|
||||
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
chunks = await self._client.chat.completions.create(**params)
|
||||
return chunks
|
||||
|
||||
async def _stream_chat_completions(
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
|
||||
# base64 encode any images
|
||||
for message in messages:
|
||||
if message.get("mime_type") == "image/jpeg":
|
||||
encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
|
||||
text = message["content"]
|
||||
message["content"] = [
|
||||
{"type": "text", "text": text},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
||||
},
|
||||
]
|
||||
del message["data"]
|
||||
del message["mime_type"]
|
||||
|
||||
chunks = await self.get_chat_completions(context, messages)
|
||||
|
||||
return chunks
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
|
||||
context
|
||||
)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
# We're streaming the LLM response to enable the fastest response times.
|
||||
# For text, we just yield each chunk as we receive it and count on consumers
|
||||
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
|
||||
#
|
||||
# If the LLM is a function call, we'll do some coalescing here.
|
||||
# If the response contains a function name, we'll yield a frame to tell consumers
|
||||
# that they can start preparing to call the function with that name.
|
||||
# We accumulate all the arguments for the rest of the streamed response, then when
|
||||
# the response is done, we package up all the arguments and the function name and
|
||||
# yield a frame containing the function name and the arguments.
|
||||
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||
):
|
||||
if self.has_function(function_name):
|
||||
run_llm = False
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: OpenAILLMContext = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = OpenAILLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext()
|
||||
context.add_image_frame_message(
|
||||
format=frame.format, size=frame.size, image=frame.image, text=frame.text
|
||||
)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self._process_context(context)
|
||||
except httpx.TimeoutException:
|
||||
await self._call_event_handler("on_completion_timeout")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
58
src/pipecat/services/openai/image.py
Normal file
58
src/pipecat/services/openai/image.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import io
|
||||
from typing import AsyncGenerator, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
URLImageRawFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import ImageGenService
|
||||
|
||||
|
||||
class OpenAIImageGenService(ImageGenService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
model: str = "dall-e-3",
|
||||
):
|
||||
super().__init__()
|
||||
self.set_model_name(model)
|
||||
self._image_size = image_size
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating image from prompt: {prompt}")
|
||||
|
||||
image = await self._client.images.generate(
|
||||
prompt=prompt, model=self.model_name, n=1, size=self._image_size
|
||||
)
|
||||
|
||||
image_url = image.data[0].url
|
||||
|
||||
if not image_url:
|
||||
logger.error(f"{self} No image provided in response: {image}")
|
||||
yield ErrorFrame("Image generation failed")
|
||||
return
|
||||
|
||||
# Load the image from the url
|
||||
async with self._aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
|
||||
yield frame
|
||||
142
src/pipecat/services/openai/llm.py
Normal file
142
src/pipecat/services/openai/llm.py
Normal file
@@ -0,0 +1,142 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallCancelFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
UserImageRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIContextAggregatorPair:
|
||||
_user: "OpenAIUserContextAggregator"
|
||||
_assistant: "OpenAIAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "OpenAIUserContextAggregator":
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "OpenAIAssistantContextAggregator":
|
||||
return self._assistant
|
||||
|
||||
|
||||
class OpenAILLMService(BaseOpenAILLMService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model=model, params=params, **kwargs)
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
context: OpenAILLMContext,
|
||||
*,
|
||||
user_kwargs: Mapping[str, Any] = {},
|
||||
assistant_kwargs: Mapping[str, Any] = {},
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
"""Create an instance of OpenAIContextAggregatorPair from an
|
||||
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||
assistant aggregators can be provided.
|
||||
|
||||
Args:
|
||||
context (OpenAILLMContext): The LLM context.
|
||||
user_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the user context aggregator constructor. Defaults
|
||||
to an empty mapping.
|
||||
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
|
||||
arguments for the assistant context aggregator
|
||||
constructor. Defaults to an empty mapping.
|
||||
|
||||
Returns:
|
||||
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||
the user and one for the assistant, encapsulated in an
|
||||
OpenAIContextAggregatorPair.
|
||||
|
||||
"""
|
||||
context.set_llm_adapter(self.get_llm_adapter())
|
||||
user = OpenAIUserContextAggregator(context, **user_kwargs)
|
||||
assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
if frame.result:
|
||||
result = json.dumps(frame.result)
|
||||
await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
|
||||
else:
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "COMPLETED"
|
||||
)
|
||||
|
||||
async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.function_name, frame.tool_call_id, "CANCELLED"
|
||||
)
|
||||
|
||||
async def _update_function_call_result(
|
||||
self, function_name: str, tool_call_id: str, result: Any
|
||||
):
|
||||
for message in self._context.messages:
|
||||
if (
|
||||
message["role"] == "tool"
|
||||
and message["tool_call_id"]
|
||||
and message["tool_call_id"] == tool_call_id
|
||||
):
|
||||
message["content"] = result
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
await self._update_function_call_result(
|
||||
frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
|
||||
)
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.request.context,
|
||||
)
|
||||
66
src/pipecat/services/openai/stt.py
Normal file
66
src/pipecat/services/openai/stt.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
class OpenAISTTService(BaseWhisperSTTService):
|
||||
"""OpenAI Speech-to-Text service that generates text from audio.
|
||||
|
||||
Uses OpenAI's transcription API to convert audio to text. Requires an OpenAI API key
|
||||
set via the api_key parameter or OPENAI_API_KEY environment variable.
|
||||
|
||||
Args:
|
||||
model: Model to use — either gpt-4o or Whisper. Defaults to "gpt-4o-transcribe".
|
||||
api_key: OpenAI API key. Defaults to None.
|
||||
base_url: API base URL. Defaults to None.
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "gpt-4o-transcribe",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
assert self._language is not None # Assigned in the BaseWhisperSTTService class
|
||||
|
||||
# Build kwargs dict with only set parameters
|
||||
kwargs = {
|
||||
"file": ("audio.wav", audio, "audio/wav"),
|
||||
"model": self.model_name,
|
||||
"language": self._language,
|
||||
}
|
||||
|
||||
if self._prompt is not None:
|
||||
kwargs["prompt"] = self._prompt
|
||||
|
||||
if self._temperature is not None:
|
||||
kwargs["temperature"] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
129
src/pipecat/services/openai/tts.py
Normal file
129
src/pipecat/services/openai/tts.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import AsyncGenerator, Dict, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI, BadRequestError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
|
||||
|
||||
VALID_VOICES: Dict[str, ValidVoice] = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
|
||||
|
||||
class OpenAITTSService(TTSService):
|
||||
"""OpenAI Text-to-Speech service that generates audio from text.
|
||||
|
||||
This service uses the OpenAI TTS API to generate PCM-encoded audio at 24kHz.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key. Defaults to None.
|
||||
voice: Voice ID to use. Defaults to "alloy".
|
||||
model: TTS model to use. Defaults to "gpt-4o-mini-tts".
|
||||
sample_rate: Output audio sample rate in Hz. Defaults to None.
|
||||
**kwargs: Additional keyword arguments passed to TTSService.
|
||||
|
||||
The service returns PCM-encoded audio at the specified sample rate.
|
||||
|
||||
"""
|
||||
|
||||
OPENAI_SAMPLE_RATE = 24000 # OpenAI TTS always outputs at 24kHz
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
voice: str = "alloy",
|
||||
model: str = "gpt-4o-mini-tts",
|
||||
sample_rate: Optional[int] = None,
|
||||
instructions: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
f"OpenAI TTS only supports {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
|
||||
f"Current rate of {self.sample_rate}Hz may cause issues."
|
||||
)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice)
|
||||
self._instructions = instructions
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.info(f"Switching TTS model to: [{model}]")
|
||||
self.set_model_name(model)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if self.sample_rate != self.OPENAI_SAMPLE_RATE:
|
||||
logger.warning(
|
||||
f"OpenAI TTS requires {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
|
||||
f"Current rate of {self.sample_rate}Hz may cause issues."
|
||||
)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Setup extra body parameters
|
||||
extra_body = {}
|
||||
if self._instructions:
|
||||
extra_body["instructions"] = self._instructions
|
||||
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
input=text or " ", # Text must contain at least one character
|
||||
model=self.model_name,
|
||||
voice=VALID_VOICES[self._voice_id],
|
||||
response_format="pcm",
|
||||
extra_body=extra_body,
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
error = await r.text()
|
||||
logger.error(
|
||||
f"{self} error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
yield ErrorFrame(
|
||||
f"Error getting audio (status: {r.status_code}, error: {error})"
|
||||
)
|
||||
return
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
CHUNK_SIZE = 1024
|
||||
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in r.iter_bytes(CHUNK_SIZE):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
@@ -4,8 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .openai import OpenAIRealtimeBetaLLMService
|
||||
|
||||
@@ -6,23 +6,18 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
FunctionCallResultProperties,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai import (
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAIUserContextAggregator,
|
||||
)
|
||||
|
||||
@@ -12,8 +12,6 @@ from typing import Any, Mapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
@@ -23,6 +21,7 @@ except ModuleNotFoundError as e:
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
@@ -55,7 +54,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.openai import OpenAIContextAggregatorPair
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from . import events
|
||||
|
||||
13
src/pipecat/services/openpipe/__init__.py
Normal file
13
src/pipecat/services/openpipe/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openpipe", "openpipe.llm")
|
||||
@@ -10,16 +10,14 @@ from loguru import logger
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
try:
|
||||
from openpipe import AsyncOpenAI as OpenPipeAI
|
||||
from openpipe import AsyncStream
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenPipe, you need to `pip install pipecat-ai[openpipe]`. Also, set `OPENPIPE_API_KEY` and `OPENAI_API_KEY` environment variables."
|
||||
)
|
||||
logger.error("In order to use OpenPipe, you need to `pip install pipecat-ai[openpipe]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
13
src/pipecat/services/openrouter/__init__.py
Normal file
13
src/pipecat/services/openrouter/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openrouter", "openrouter.llm")
|
||||
@@ -8,7 +8,7 @@ from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class OpenRouterLLMService(OpenAILLMService):
|
||||
13
src/pipecat/services/perplexity/__init__.py
Normal file
13
src/pipecat/services/perplexity/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "perplexity", "perplexity.llm")
|
||||
@@ -6,13 +6,12 @@
|
||||
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
from openai import NOT_GIVEN, AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class PerplexityLLMService(OpenAILLMService):
|
||||
13
src/pipecat/services/piper/__init__.py
Normal file
13
src/pipecat/services/piper/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "piper", "piper.tts")
|
||||
13
src/pipecat/services/playht/__init__.py
Normal file
13
src/pipecat/services/playht/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "playht", "playht.tts")
|
||||
@@ -36,9 +36,7 @@ try:
|
||||
from pyht.client import Language as PlayHTLanguage
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use PlayHT, you need to `pip install pipecat-ai[playht]`. Also, set `PLAY_HT_USER_ID` and `PLAY_HT_API_KEY` environment variables."
|
||||
)
|
||||
logger.error("In order to use PlayHT, you need to `pip install pipecat-ai[playht]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
13
src/pipecat/services/qwen/__init__.py
Normal file
13
src/pipecat/services/qwen/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "qwen", "qwen.llm")
|
||||
@@ -4,11 +4,9 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAISTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class QwenLLMService(OpenAILLMService):
|
||||
13
src/pipecat/services/rime/__init__.py
Normal file
13
src/pipecat/services/rime/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "rime", "rime.tts")
|
||||
@@ -34,9 +34,7 @@ try:
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Rime, you need to `pip install pipecat-ai[rime]`. Also, set `RIME_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use Rime, you need to `pip install pipecat-ai[rime]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
14
src/pipecat/services/riva/__init__.py
Normal file
14
src/pipecat/services/riva/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .stt import *
|
||||
from .tts import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "riva", "riva.[stt,tts]")
|
||||
@@ -17,11 +17,8 @@ from pipecat.frames.frames import (
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService, TTSService
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
@@ -30,100 +27,9 @@ try:
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use nvidia riva TTS or STT, you need to `pip install pipecat-ai[riva]`. Also, set `NVIDIA_API_KEY` environment variable."
|
||||
)
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
FASTPITCH_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class FastPitchTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "English-US.Female-1",
|
||||
sample_rate: Optional[int] = None,
|
||||
function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
|
||||
self.set_model_name("fastpitch-hifigan-tts")
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
audio_prompt_file=None,
|
||||
quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
class ParakeetSTTService(STTService):
|
||||
class InputParams(BaseModel):
|
||||
117
src/pipecat/services/riva/tts.py
Normal file
117
src/pipecat/services/riva/tts.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
FASTPITCH_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class FastPitchTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "English-US.Female-1",
|
||||
sample_rate: Optional[int] = None,
|
||||
function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
|
||||
self.set_model_name("fastpitch-hifigan-tts")
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
audio_prompt_file=None,
|
||||
quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
13
src/pipecat/services/simli/__init__.py
Normal file
13
src/pipecat/services/simli/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .video import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "simli", "simli.video")
|
||||
13
src/pipecat/services/tavus/__init__.py
Normal file
13
src/pipecat/services/tavus/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .video import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "tavus", "tavus.video")
|
||||
@@ -4,7 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
"""This module implements Tavus as a sink transport layer"""
|
||||
|
||||
import base64
|
||||
@@ -1,69 +0,0 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
from services.ai_service import AIService
|
||||
|
||||
# Note that Cloudflare's AI workers are still in beta.
|
||||
# https://developers.cloudflare.com/workers-ai/
|
||||
|
||||
|
||||
class CloudflareAIService(AIService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.cloudflare_account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
|
||||
self.cloudflare_api_token = os.getenv("CLOUDFLARE_API_TOKEN")
|
||||
|
||||
self.api_base_url = (
|
||||
f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/ai/run/"
|
||||
)
|
||||
self.headers = {"Authorization": f"Bearer {self.cloudflare_api_token}"}
|
||||
|
||||
# base endpoint, used by the others
|
||||
def run(self, model, input):
|
||||
response = requests.post(f"{self.api_base_url}{model}", headers=self.headers, json=input)
|
||||
return response.json()
|
||||
|
||||
# https://developers.cloudflare.com/workers-ai/models/llm/
|
||||
def run_llm(self, messages, latest_user_message=None, stream=True):
|
||||
input = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a friendly assistant"},
|
||||
{"role": "user", "content": sentence},
|
||||
]
|
||||
}
|
||||
|
||||
return self.run("@cf/meta/llama-2-7b-chat-int8", input)
|
||||
|
||||
# https://developers.cloudflare.com/workers-ai/models/translation/
|
||||
def run_text_translation(self, sentence, source_language, target_language):
|
||||
return self.run(
|
||||
"@cf/meta/m2m100-1.2b",
|
||||
{"text": sentence, "source_lang": source_language, "target_lang": target_language},
|
||||
)
|
||||
|
||||
# https://developers.cloudflare.com/workers-ai/models/sentiment-analysis/
|
||||
def run_text_sentiment(self, sentence):
|
||||
return self.run("@cf/huggingface/distilbert-sst-2-int8", {"text": sentence})
|
||||
|
||||
# https://developers.cloudflare.com/workers-ai/models/image-classification/
|
||||
def run_image_classification(self, image_url):
|
||||
response = requests.get(image_url)
|
||||
|
||||
if response.status_code != 200:
|
||||
return {"error": "There was a problem downloading the image."}
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.content
|
||||
inputs = {"image": list(data)}
|
||||
|
||||
return self.run("@cf/microsoft/resnet-50", inputs)
|
||||
|
||||
# https://developers.cloudflare.com/workers-ai/models/embedding/
|
||||
def run_embeddings(self, texts, size="medium"):
|
||||
models = {
|
||||
"small": "@cf/baai/bge-small-en-v1.5", # 384 output dimensions
|
||||
"medium": "@cf/baai/bge-base-en-v1.5", # 768 output dimensions
|
||||
"large": "@cf/baai/bge-large-en-v1.5", # 1024 output dimensions
|
||||
}
|
||||
|
||||
return self.run(models[size], {"text": texts})
|
||||
@@ -1,30 +0,0 @@
|
||||
import os
|
||||
|
||||
import openai
|
||||
|
||||
# To use Google Cloud's AI products, you'll need to install Google Cloud
|
||||
# CLI and enable the TTS and in your project:
|
||||
# https://cloud.google.com/sdk/docs/install
|
||||
from google.cloud import texttospeech
|
||||
from services.ai_service import AIService
|
||||
|
||||
|
||||
class GoogleAIService(AIService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.client = texttospeech.TextToSpeechClient()
|
||||
self.voice = texttospeech.VoiceSelectionParams(
|
||||
language_code="en-GB", name="en-GB-Neural2-F"
|
||||
)
|
||||
|
||||
self.audio_config = texttospeech.AudioConfig(
|
||||
audio_encoding=texttospeech.AudioEncoding.LINEAR16, sample_rate_hertz=16000
|
||||
)
|
||||
|
||||
def run_tts(self, sentence):
|
||||
synthesis_input = texttospeech.SynthesisInput(text=sentence.strip())
|
||||
result = self.client.synthesize_speech(
|
||||
input=synthesis_input, voice=self.voice, audio_config=self.audio_config
|
||||
)
|
||||
return result
|
||||
@@ -1,33 +0,0 @@
|
||||
from services.ai_service import AIService
|
||||
from transformers import pipeline
|
||||
|
||||
# These functions are just intended for testing, not production use. If
|
||||
# you'd like to use HuggingFace, you should use your own models, or do
|
||||
# some research into the specific models that will work best for your use
|
||||
# case.
|
||||
|
||||
|
||||
class HuggingFaceAIService(AIService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def run_text_sentiment(self, sentence):
|
||||
classifier = pipeline("sentiment-analysis")
|
||||
return classifier(sentence)
|
||||
|
||||
# available models at https://huggingface.co/Helsinki-NLP (**not all
|
||||
# models use 2-character language codes**)
|
||||
def run_text_translation(self, sentence, source_language, target_language):
|
||||
translator = pipeline(
|
||||
f"translation", model=f"Helsinki-NLP/opus-mt-{source_language}-{target_language}"
|
||||
)
|
||||
|
||||
return translator(sentence)[0]["translation_text"]
|
||||
|
||||
def run_text_summarization(self, sentence):
|
||||
summarizer = pipeline("summarization")
|
||||
return summarizer(sentence)
|
||||
|
||||
def run_image_classification(self, image_path):
|
||||
classifier = pipeline("image-classification")
|
||||
return classifier(image_path)
|
||||
@@ -1,28 +0,0 @@
|
||||
import io
|
||||
import time
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from services.ai_service import AIService
|
||||
|
||||
|
||||
class MockAIService(AIService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def run_tts(self, sentence):
|
||||
print("running tts", sentence)
|
||||
time.sleep(2)
|
||||
|
||||
def run_image_gen(self, sentence):
|
||||
image_url = "https://d3d00swyhr67nd.cloudfront.net/w800h800/collection/ASH/ASHM/ASH_ASHM_WA1940_2_22-001.jpg"
|
||||
response = requests.get(image_url)
|
||||
image_stream = io.BytesIO(response.content)
|
||||
image = Image.open(image_stream)
|
||||
time.sleep(1)
|
||||
return (image_url, image.tobytes(), image.size)
|
||||
|
||||
def run_llm(self, messages, latest_user_message=None, stream=True):
|
||||
for i in range(5):
|
||||
time.sleep(1)
|
||||
yield ({"choices": [{"delta": {"content": f"hello {i}!"}}]})
|
||||
13
src/pipecat/services/together/__init__.py
Normal file
13
src/pipecat/services/together/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "together", "together.llm")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user