fal.ai integration (#3)

* fal.ai image gen

* some sample and readme updates

* holy cow this is fast

* basic image-gen working

* starting audio prompt and reset

* short confirmation words

* moved fal module to pyproject.toml

---------

Co-authored-by: Moishe Lettvin <moishel@gmail.com>
This commit is contained in:
chadbailey59
2024-01-17 12:08:00 -06:00
committed by GitHub
parent b1e6a89acd
commit 6e8ebbd34c
5 changed files with 176 additions and 5 deletions

View File

@@ -15,7 +15,8 @@ dependencies = [
"azure-cognitiveservices-speech",
"pyht",
"opentelemetry-sdk",
"aiohttp"
"aiohttp",
"fal"
]
[tool.setuptools.packages.find]

View File

@@ -1,2 +1,2 @@
Pillow==10.1.0
typing_extensions==4.9.0
typing_extensions==4.9.0

View File

@@ -0,0 +1,51 @@
import fal
import aiohttp
import asyncio
import io
import json
from PIL import Image
from dailyai.services.ai_services import LLMService, TTSService, ImageGenService
# Fal expects FAL_KEY_ID and FAL_KEY_SECRET to be set in the env
class FalImageGenService(ImageGenService):
def __init__(self):
super().__init__()
async def run_image_gen(self, sentence, size) -> tuple[str, bytes]:
def get_image_url(sentence, size):
print("starting fal submit...")
handler = fal.apps.submit(
"110602490-fast-sdxl",
arguments={
"prompt": sentence
},
)
print("past fal handler init, about to wait for iter_events...")
for event in handler.iter_events():
if isinstance(event, fal.apps.InProgress):
print('Request in progress')
print(event.logs)
result = handler.get()
image_url = result["images"][0]["url"] if result else None
if not image_url:
raise Exception("Image generation failed")
return image_url
print(f"fetching image url...")
image_url = await asyncio.to_thread(get_image_url, sentence, size)
print(f"got image url, downloading image...")
# Load the image from the url
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as response:
print("got image response")
image_stream = io.BytesIO(await response.content.read())
print("read image stream")
image = Image.open(image_stream)
return (image_url, image.tobytes())
# return (image_url, dalle_im.tobytes())

113
src/samples/image-gen.py Normal file
View File

@@ -0,0 +1,113 @@
import argparse
import asyncio
import requests
import time
import urllib.parse
import random
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.services.fal_ai_services import FalImageGenService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
async def main(room_url:str, token):
global transport
global llm
global tts
transport = DailyTransportService(
room_url,
token,
"Imagebot",
1,
)
transport.mic_enabled = True
transport.camera_enabled = True
transport.mic_sample_rate = 16000
transport.camera_width = 1024
transport.camera_height = 1024
llm = AzureLLMService()
tts = AzureTTSService()
img = FalImageGenService()
async def handle_transcriptions():
print("handle_transcriptions got called")
sentence = ""
async for message in transport.get_transcriptions():
print(f"transcription message: {message}")
if message["session_id"] == transport.my_participant_id:
continue
finder = message["text"].find("start over")
print(f"finder: {finder}")
if finder >= 0:
async for audio in tts.run_tts(f"Resetting."):
transport.output_queue.put(QueueFrame(FrameType.AUDIO_FRAME, audio))
sentence = ""
continue
# todo: we could differentiate between transcriptions from different participants
sentence += f" {message['text']}"
print(f"sentence is now: {sentence}")
# TODO: Cache this audio
phrase = random.choice(["OK.", "Got it.", "Sure.", "You bet.", "Sure thing."])
async for audio in tts.run_tts(phrase):
transport.output_queue.put(QueueFrame(FrameType.AUDIO_FRAME, audio))
img_result = img.run_image_gen(sentence, "1024x1024")
awaited_img = await asyncio.gather(img_result)
transport.output_queue.put(
[
QueueFrame(FrameType.IMAGE_FRAME, awaited_img[0][1]),
]
)
@transport.event_handler("on_participant_joined")
async def on_participant_joined(transport, participant):
print(f"participant joined: {participant['info']['userName']}")
if participant["info"]["isLocal"]:
return
async for audio in tts.run_tts("Describe an image, and I'll create it."):
audio_generator = tts.run_tts(f"Hello, {participant['info']['userName']}! Describe an image and I'll create it. To start over, just say 'start over'.")
async for audio in audio_generator:
transport.output_queue.put(QueueFrame(FrameType.AUDIO_FRAME, audio))
transport.transcription_settings["extra"]["punctuate"] = False
transport.transcription_settings["extra"]["endpointing"] = False
await asyncio.gather(transport.run(), handle_transcriptions())
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
parser.add_argument(
"-u", "--url", type=str, required=True, help="URL of the Daily room to join"
)
parser.add_argument(
"-k",
"--apikey",
type=str,
required=True,
help="Daily API Key (needed to create token)",
)
args, unknown = parser.parse_known_args()
# Create a meeting token for the given room with an expiration 1 hour in the future.
room_name: str = urllib.parse.urlparse(args.url).path[1:]
expiration: float = time.time() + 60 * 60
res: requests.Response = requests.post(
f"https://api.daily.co/v1/meeting-tokens",
headers={"Authorization": f"Bearer {args.apikey}"},
json={
"properties": {"room_name": room_name, "is_owner": True, "exp": expiration}
},
)
if res.status_code != 200:
raise Exception(f"Failed to create meeting token: {res.status_code} {res.text}")
token: str = res.json()["token"]
asyncio.run(main(args.url, token))

View File

@@ -9,6 +9,7 @@ from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.open_ai_services import OpenAIImageGenService
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.fal_ai_services import FalImageGenService
async def main(room_url):
meeting_duration_minutes = 5
@@ -25,8 +26,10 @@ async def main(room_url):
transport.camera_height = 1024
llm = AzureLLMService()
tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
dalle = OpenAIImageGenService()
#tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
tts = ElevenLabsTTSService()
dalle = FalImageGenService()
# dalle = OpenAIImageGenService()
# Get a complete audio chunk from the given text. Splitting this into its own
# coroutine lets us ensure proper ordering of the audio chunks on the output queue.
@@ -77,7 +80,8 @@ async def main(room_url):
months: list[str] = [
"January",
"February",
"March",
"March"]
"""
"April",
"May",
"June",
@@ -88,6 +92,7 @@ async def main(room_url):
"November",
"December",
]
"""
@transport.event_handler("on_first_other_participant_joined")
async def on_first_other_participant_joined(transport):
@@ -96,6 +101,7 @@ async def main(room_url):
# likely no delay between months, but the months won't display in order.
for month_data_task in asyncio.as_completed(month_tasks):
data = await month_data_task
print(f"got data, queueing frames...")
transport.output_queue.put(
[
QueueFrame(FrameType.IMAGE, data["image"]),