fal.ai integration (#3)
* fal.ai image gen * some sample and readme updates * holy cow this is fast * basic image-gen working * starting audio prompt and reset * short confirmation words * moved fal module to pyproject.toml --------- Co-authored-by: Moishe Lettvin <moishel@gmail.com>
This commit is contained in:
@@ -15,7 +15,8 @@ dependencies = [
|
||||
"azure-cognitiveservices-speech",
|
||||
"pyht",
|
||||
"opentelemetry-sdk",
|
||||
"aiohttp"
|
||||
"aiohttp",
|
||||
"fal"
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
Pillow==10.1.0
|
||||
typing_extensions==4.9.0
|
||||
typing_extensions==4.9.0
|
||||
51
src/dailyai/services/fal_ai_services.py
Normal file
51
src/dailyai/services/fal_ai_services.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import fal
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
from PIL import Image
|
||||
|
||||
|
||||
from dailyai.services.ai_services import LLMService, TTSService, ImageGenService
|
||||
# Fal expects FAL_KEY_ID and FAL_KEY_SECRET to be set in the env
|
||||
class FalImageGenService(ImageGenService):
    """Image generation service backed by a fal.ai hosted SDXL app.

    The fal client expects FAL_KEY_ID and FAL_KEY_SECRET to be set in the
    environment, so the constructor takes no credentials.
    """

    # Identifier of the fal app that requests are submitted to.
    FAL_APP_ID = "110602490-fast-sdxl"

    def __init__(self):
        super().__init__()

    def _request_image_url(self, sentence):
        """Submit ``sentence`` to fal, wait for the result, return the image URL.

        This call blocks while iterating fal's progress events, so it is
        meant to be run off the event loop (see ``run_image_gen``).

        Raises:
            RuntimeError: if the fal result carries no usable image URL.
        """
        print("starting fal submit...")
        handler = fal.apps.submit(
            self.FAL_APP_ID,
            arguments={"prompt": sentence},
        )
        print("past fal handler init, about to wait for iter_events...")
        for event in handler.iter_events():
            if isinstance(event, fal.apps.InProgress):
                print('Request in progress')
                print(event.logs)

        result = handler.get()

        # A missing key or an empty image list is reported the same way as an
        # empty result instead of leaking a bare KeyError/IndexError.
        try:
            image_url = result["images"][0]["url"] if result else None
        except (KeyError, IndexError, TypeError):
            image_url = None
        if not image_url:
            raise RuntimeError("Image generation failed")

        return image_url

    async def run_image_gen(self, sentence, size) -> tuple[str, bytes]:
        """Generate an image for ``sentence`` and download it.

        Args:
            sentence: Text prompt sent to the fal app.
            size: Accepted for interface compatibility; it is not forwarded
                to fal by this implementation.

        Returns:
            A ``(image_url, image_bytes)`` tuple, where ``image_bytes`` is the
            output of ``PIL.Image.tobytes()`` for the downloaded image.

        Raises:
            RuntimeError: if fal returns no image URL.
            aiohttp.ClientResponseError: if the image download responds with
                an HTTP error status.
        """
        print("fetching image url...")
        # The fal client is synchronous; run it in a worker thread so the
        # event loop stays responsive while the image is generated.
        image_url = await asyncio.to_thread(self._request_image_url, sentence)
        print("got image url, downloading image...")

        # Load the image from the url
        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as response:
                # Fail loudly on 4xx/5xx rather than handing an error page
                # to PIL and getting an unrelated decode error.
                response.raise_for_status()
                print("got image response")
                image_stream = io.BytesIO(await response.read())
                print("read image stream")

        image = Image.open(image_stream)
        return (image_url, image.tobytes())
|
||||
|
||||
# return (image_url, dalle_im.tobytes())
|
||||
113
src/samples/image-gen.py
Normal file
113
src/samples/image-gen.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import requests
|
||||
import time
|
||||
import urllib.parse
|
||||
import random
|
||||
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.fal_ai_services import FalImageGenService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
|
||||
async def main(room_url: str, token):
    """Join a Daily room as "Imagebot" and turn spoken descriptions into images.

    Transcribed speech from other participants is accumulated into a prompt.
    After each transcription the bot speaks a short confirmation, generates an
    image for the prompt so far, and queues it as a camera frame. Saying
    "start over" clears the accumulated prompt.

    Args:
        room_url: URL of the Daily room to join.
        token: Meeting token used to join the room.
    """
    # Published as module globals, as in the original sample.
    global transport
    global llm
    global tts

    transport = DailyTransportService(
        room_url,
        token,
        "Imagebot",
        1,  # NOTE(review): presumably the meeting duration in minutes; confirm
    )
    transport.mic_enabled = True
    transport.camera_enabled = True
    transport.mic_sample_rate = 16000
    transport.camera_width = 1024
    transport.camera_height = 1024

    llm = AzureLLMService()
    tts = AzureTTSService()
    img = FalImageGenService()

    async def speak(target_transport, text):
        """Synthesize ``text`` and queue each audio chunk for output."""
        async for audio in tts.run_tts(text):
            target_transport.output_queue.put(
                QueueFrame(FrameType.AUDIO_FRAME, audio)
            )

    async def handle_transcriptions():
        """Consume transcriptions, building the prompt and emitting images."""
        print("handle_transcriptions got called")

        sentence = ""
        async for message in transport.get_transcriptions():
            print(f"transcription message: {message}")
            # Ignore the bot's own speech.
            if message["session_id"] == transport.my_participant_id:
                continue

            finder = message["text"].find("start over")
            print(f"finder: {finder}")
            if finder >= 0:
                await speak(transport, "Resetting.")
                sentence = ""
                continue

            # todo: we could differentiate between transcriptions from different participants
            sentence += f" {message['text']}"
            print(f"sentence is now: {sentence}")

            # TODO: Cache this audio
            phrase = random.choice(["OK.", "Got it.", "Sure.", "You bet.", "Sure thing."])
            await speak(transport, phrase)

            # A single coroutine needs no gather(); await it directly and
            # keep only the image bytes from the (url, bytes) result.
            _, image_bytes = await img.run_image_gen(sentence, "1024x1024")
            transport.output_queue.put(
                [
                    QueueFrame(FrameType.IMAGE_FRAME, image_bytes),
                ]
            )

    @transport.event_handler("on_participant_joined")
    async def on_participant_joined(transport, participant):
        """Greet each remote participant once when they join."""
        print(f"participant joined: {participant['info']['userName']}")
        if participant["info"]["isLocal"]:
            return
        # The greeting is synthesized and queued exactly once. The earlier
        # version wrapped this in a second TTS loop whose audio was never
        # queued and which re-created the greeting for every chunk.
        await speak(
            transport,
            f"Hello, {participant['info']['userName']}! Describe an image and I'll create it. To start over, just say 'start over'.",
        )

    transport.transcription_settings["extra"]["punctuate"] = False
    transport.transcription_settings["extra"]["endpointing"] = False
    await asyncio.gather(transport.run(), handle_transcriptions())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: read the room URL and Daily API key from the
    # command line, mint a meeting token for that room, then run the bot.
    cli = argparse.ArgumentParser(description="Simple Daily Bot Sample")
    cli.add_argument(
        "-u",
        "--url",
        type=str,
        required=True,
        help="URL of the Daily room to join",
    )
    cli.add_argument(
        "-k", "--apikey", type=str, required=True, help="Daily API Key (needed to create token)"
    )

    # Unrecognized arguments are tolerated rather than rejected.
    args, unknown = cli.parse_known_args()

    # Create a meeting token for the given room with an expiration 1 hour in the future.
    # The room name is the URL path without its leading slash.
    room_name: str = urllib.parse.urlparse(args.url).path[1:]
    expiration: float = time.time() + 60 * 60
    token_properties = {
        "room_name": room_name,
        "is_owner": True,
        "exp": expiration,
    }

    response: requests.Response = requests.post(
        "https://api.daily.co/v1/meeting-tokens",
        headers={"Authorization": f"Bearer {args.apikey}"},
        json={"properties": token_properties},
    )

    # Anything other than 200 means no token was issued; stop here.
    if response.status_code != 200:
        raise Exception(f"Failed to create meeting token: {response.status_code} {response.text}")

    token: str = response.json()["token"]

    asyncio.run(main(args.url, token))
|
||||
@@ -9,6 +9,7 @@ from dailyai.services.azure_ai_services import AzureLLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
from dailyai.services.open_ai_services import OpenAIImageGenService
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.fal_ai_services import FalImageGenService
|
||||
|
||||
async def main(room_url):
|
||||
meeting_duration_minutes = 5
|
||||
@@ -25,8 +26,10 @@ async def main(room_url):
|
||||
transport.camera_height = 1024
|
||||
|
||||
llm = AzureLLMService()
|
||||
tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
|
||||
dalle = OpenAIImageGenService()
|
||||
#tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
|
||||
tts = ElevenLabsTTSService()
|
||||
dalle = FalImageGenService()
|
||||
# dalle = OpenAIImageGenService()
|
||||
|
||||
# Get a complete audio chunk from the given text. Splitting this into its own
|
||||
# coroutine lets us ensure proper ordering of the audio chunks on the output queue.
|
||||
@@ -77,7 +80,8 @@ async def main(room_url):
|
||||
months: list[str] = [
|
||||
"January",
|
||||
"February",
|
||||
"March",
|
||||
"March"]
|
||||
"""
|
||||
"April",
|
||||
"May",
|
||||
"June",
|
||||
@@ -88,6 +92,7 @@ async def main(room_url):
|
||||
"November",
|
||||
"December",
|
||||
]
|
||||
"""
|
||||
|
||||
@transport.event_handler("on_first_other_participant_joined")
|
||||
async def on_first_other_participant_joined(transport):
|
||||
@@ -96,6 +101,7 @@ async def main(room_url):
|
||||
# likely no delay between months, but the months won't display in order.
|
||||
for month_data_task in asyncio.as_completed(month_tasks):
|
||||
data = await month_data_task
|
||||
print(f"got data, queueing frames...")
|
||||
transport.output_queue.put(
|
||||
[
|
||||
QueueFrame(FrameType.IMAGE, data["image"]),
|
||||
|
||||
Reference in New Issue
Block a user