Compare commits

...

2 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
8d6b8b035e Merge pull request #332 from pipecat-ai/aleix/allow-internal-http-sessions
services: allow internal http sessions if none is given
2024-07-31 15:51:52 -07:00
Aleix Conchillo Flaqué
0a15874c12 services: allow internal http sessions if none is given 2024-07-30 17:44:18 -07:00
6 changed files with 49 additions and 13 deletions

View File

@@ -171,12 +171,12 @@ class AzureImageGenServiceREST(ImageGenService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
image_size: str,
api_key: str,
endpoint: str,
model: str,
api_version="2023-06-01-preview",
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
@@ -184,8 +184,14 @@ class AzureImageGenServiceREST(ImageGenService):
self._azure_endpoint = endpoint
self._api_version = api_version
self._model = model
self._aiohttp_session = aiohttp_session
self._image_size = image_size
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"

View File

@@ -45,25 +45,31 @@ class DeepgramTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
api_key: str,
voice: str = "aura-helios-en",
base_url: str = "https://api.deepgram.com/v1/speak",
sample_rate: int = 16000,
encoding: str = "linear16",
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)
self._voice = voice
self._api_key = api_key
self._aiohttp_session = aiohttp_session
self._base_url = base_url
self._sample_rate = sample_rate
self._encoding = encoding
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
def can_generate_metrics(self) -> bool:
return True
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice = voice

View File

@@ -19,21 +19,27 @@ class ElevenLabsTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
api_key: str,
voice_id: str,
model: str = "eleven_turbo_v2",
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)
self._api_key = api_key
self._voice_id = voice_id
self._aiohttp_session = aiohttp_session
self._model = model
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
def can_generate_metrics(self) -> bool:
return True
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice

View File

@@ -39,18 +39,24 @@ class FalImageGenService(ImageGenService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
params: InputParams,
model: str = "fal-ai/fast-sdxl",
key: str | None = None,
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
self._model = model
self._params = params
self._aiohttp_session = aiohttp_session
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
if key:
os.environ["FAL_KEY"] = key
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating image from prompt: {prompt}")

View File

@@ -253,16 +253,22 @@ class OpenAIImageGenService(ImageGenService):
def __init__(
self,
*,
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
aiohttp_session: aiohttp.ClientSession,
api_key: str,
model: str = "dall-e-3",
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
self._model = model
self._image_size = image_size
self._client = AsyncOpenAI(api_key=api_key)
self._aiohttp_session = aiohttp_session
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating image from prompt: {prompt}")

View File

@@ -38,22 +38,28 @@ class XTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
voice_id: str,
language: str,
base_url: str,
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)
self._voice_id = voice_id
self._language = language
self._base_url = base_url
self._aiohttp_session = aiohttp_session
self._studio_speakers = requests.get(self._base_url + "/studio_speakers").json()
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
def can_generate_metrics(self) -> bool:
return True
async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()
async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice