Compare commits
13 Commits
hush/callT
...
cb/valoran
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06b03dcc33 | ||
|
|
af4ab95713 | ||
|
|
a8d618ede1 | ||
|
|
d6108dae5c | ||
|
|
1f9c5d132f | ||
|
|
3794e86868 | ||
|
|
5a47a3d5cd | ||
|
|
0cae54e79e | ||
|
|
aee5087a46 | ||
|
|
4548f91fdc | ||
|
|
33ea1f9925 | ||
|
|
fd5ff5fee5 | ||
|
|
237db19c40 |
@@ -3,7 +3,7 @@ import copy
|
||||
import functools
|
||||
from typing import AsyncGenerator, Awaitable, Callable
|
||||
from dailyai.queue_aggregators import LLMAssistantContextAggregator, LLMContextAggregator, LLMUserContextAggregator
|
||||
from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TranscriptionQueueFrame
|
||||
from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TranscriptionQueueFrame, UserStartedSpeakingFrame
|
||||
|
||||
|
||||
class InterruptibleConversationWrapper:
|
||||
@@ -63,9 +63,10 @@ class InterruptibleConversationWrapper:
|
||||
if frame.participantId == self._my_participant_id:
|
||||
continue
|
||||
|
||||
if current_response_task:
|
||||
if current_response_task and isinstance(frame, UserStartedSpeakingFrame):
|
||||
current_response_task.cancel()
|
||||
self._interrupt()
|
||||
|
||||
|
||||
self._current_phrase += " " + frame.text
|
||||
current_llm_messages = copy.deepcopy(self._messages)
|
||||
|
||||
@@ -52,7 +52,7 @@ class LLMContextAggregator(AIService):
|
||||
if isinstance(frame, TranscriptionQueueFrame):
|
||||
if frame.participantId == self.bot_participant_id:
|
||||
return
|
||||
|
||||
print(f"@@@ tma got a frame: {frame.text}")
|
||||
# The common case for "pass through" is receiving frames from the LLM that we'll
|
||||
# use to update the "assistant" LLM messages, but also passing the text frames
|
||||
# along to a TTS service to be spoken to the user.
|
||||
@@ -65,8 +65,11 @@ class LLMContextAggregator(AIService):
|
||||
# though we check it above
|
||||
self.sentence += frame.text
|
||||
if self.sentence.endswith((".", "?", "!")):
|
||||
self.messages.append({"role": self.role, "content": self.sentence})
|
||||
self.messages.append(
|
||||
{"role": self.role, "content": self.sentence})
|
||||
self.sentence = ""
|
||||
# for message in self.messages:
|
||||
# print(f"{message['role']}: {message['content']}")
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
else:
|
||||
# type: ignore -- the linter thinks this isn't a TextQueueFrame, even
|
||||
@@ -78,6 +81,8 @@ class LLMContextAggregator(AIService):
|
||||
# Send any dangling words that weren't finished with punctuation.
|
||||
if self.complete_sentences and self.sentence:
|
||||
self.messages.append({"role": self.role, "content": self.sentence})
|
||||
# for message in self.messages:
|
||||
# print(f"{message['role']}: {message['content']}")
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,17 @@ class LLMResponseEndQueueFrame(QueueFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ChatMessageQueueFrame(QueueFrame):
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass()
|
||||
class LLMFunctionCallFrame(QueueFrame):
|
||||
function_name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass()
|
||||
class AudioQueueFrame(QueueFrame):
|
||||
data: bytes
|
||||
|
||||
@@ -11,6 +11,8 @@ from dailyai.queue_frame import (
|
||||
ImageQueueFrame,
|
||||
LLMMessagesQueueFrame,
|
||||
LLMResponseEndQueueFrame,
|
||||
LLMFunctionCallFrame,
|
||||
ChatMessageQueueFrame,
|
||||
QueueFrame,
|
||||
TextQueueFrame,
|
||||
TranscriptionQueueFrame,
|
||||
@@ -41,7 +43,7 @@ class AIService:
|
||||
frames: Iterable[QueueFrame]
|
||||
| AsyncIterable[QueueFrame]
|
||||
| asyncio.Queue[QueueFrame],
|
||||
) -> AsyncGenerator[QueueFrame, None]:
|
||||
**kwargs) -> AsyncGenerator[QueueFrame, None]:
|
||||
try:
|
||||
if isinstance(frames, AsyncIterable):
|
||||
async for frame in frames:
|
||||
@@ -88,10 +90,25 @@ class LLMService(AIService):
|
||||
async def run_llm(self, messages) -> str:
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
async def process_frame(self, frame: QueueFrame, tool_choice: str = None) -> AsyncGenerator[QueueFrame, None]:
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
if isinstance(frame, LLMMessagesQueueFrame):
|
||||
async for text_chunk in self.run_llm_async(frame.messages):
|
||||
yield TextQueueFrame(text_chunk)
|
||||
async for text_chunk in self.run_llm_async(frame.messages, tool_choice):
|
||||
if isinstance(text_chunk, str):
|
||||
yield TextQueueFrame(text_chunk)
|
||||
elif text_chunk.function:
|
||||
if text_chunk.function.name:
|
||||
# function_name += text_chunk.function.name
|
||||
yield LLMFunctionCallFrame(function_name=text_chunk.function.name, arguments=None)
|
||||
if text_chunk.function.arguments:
|
||||
# arguments += text_chunk.function.arguments
|
||||
yield LLMFunctionCallFrame(function_name=None, arguments=text_chunk.function.arguments)
|
||||
|
||||
if (function_name and arguments):
|
||||
# yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
yield LLMResponseEndQueueFrame()
|
||||
else:
|
||||
yield frame
|
||||
@@ -129,6 +146,7 @@ class TTSService(AIService):
|
||||
self.current_sentence = ""
|
||||
|
||||
if text:
|
||||
# yield ChatMessageQueueFrame(message=text)
|
||||
async for audio_chunk in self.run_tts(text):
|
||||
yield AudioQueueFrame(audio_chunk)
|
||||
|
||||
@@ -202,6 +220,6 @@ class FrameLogger(AIService):
|
||||
if isinstance(frame, (AudioQueueFrame, ImageQueueFrame)):
|
||||
self.logger.info(f"{self.prefix}: {type(frame)}")
|
||||
else:
|
||||
print(f"{self.prefix}: {frame}")
|
||||
self.logger.info(f"{self.prefix}: {frame}")
|
||||
|
||||
yield frame
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import numpy as np
|
||||
@@ -11,9 +12,11 @@ import threading
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, AsyncIterable, BinaryIO, Iterable
|
||||
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
ChatMessageQueueFrame,
|
||||
EndStreamQueueFrame,
|
||||
ImageQueueFrame,
|
||||
QueueFrame,
|
||||
@@ -23,6 +26,7 @@ from dailyai.queue_frame import (
|
||||
UserStoppedSpeakingFrame
|
||||
)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
||||
@@ -70,7 +74,6 @@ class VADState(Enum):
|
||||
SPEAKING = 3
|
||||
STOPPING = 4
|
||||
|
||||
|
||||
class BaseTransportService():
|
||||
|
||||
def __init__(
|
||||
@@ -121,20 +124,68 @@ class BaseTransportService():
|
||||
self._is_interrupted = threading.Event()
|
||||
|
||||
self._logger: logging.Logger = logging.getLogger()
|
||||
|
||||
def update_messages(self, new_context: list[dict[str, str]], task: asyncio.Task | None):
|
||||
if task:
|
||||
if not task.cancelled():
|
||||
self._current_phrase = ""
|
||||
self._context = new_context
|
||||
|
||||
|
||||
|
||||
async def run_pipeline(self, frame):
|
||||
# TODO-CB: This exception for missing class gets eaten!
|
||||
await self._runner(frame)
|
||||
|
||||
async def run_conversation(self, runner: Iterable[QueueFrame]
|
||||
| AsyncIterable[QueueFrame]
|
||||
| asyncio.Queue[QueueFrame],
|
||||
) -> AsyncGenerator[QueueFrame, None]:
|
||||
current_response_task = None
|
||||
self._runner = runner
|
||||
|
||||
async for frame in self.get_receive_frames():
|
||||
if isinstance(frame, EndStreamQueueFrame):
|
||||
break
|
||||
# elif not isinstance(frame, TranscriptionQueueFrame):
|
||||
# continue
|
||||
# TODO-CB: Verify this is an accurate replacement
|
||||
# if hasattr(frame, 'participantId') and frame.participantId == self._my_participant_id:
|
||||
# if not isinstance(frame, UserStoppedSpeakingFrame):
|
||||
# continue
|
||||
|
||||
if current_response_task and isinstance(frame, UserStartedSpeakingFrame):
|
||||
# TODO-CB: Maybe not always interrupt? Are there frame types we can pass through?
|
||||
current_response_task.cancel()
|
||||
self.interrupt()
|
||||
|
||||
# self._current_phrase += " " + frame.text
|
||||
# current_llm_context = copy.deepcopy(self._context)
|
||||
current_response_task = asyncio.create_task(
|
||||
self.run_pipeline(
|
||||
frame)
|
||||
)
|
||||
current_response_task.add_done_callback(
|
||||
functools.partial(self.update_messages, self._context)
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
self._prerun()
|
||||
|
||||
async_output_queue_marshal_task = asyncio.create_task(self._marshal_frames())
|
||||
async_output_queue_marshal_task = asyncio.create_task(
|
||||
self._marshal_frames())
|
||||
|
||||
self._camera_thread = threading.Thread(target=self._run_camera, daemon=True)
|
||||
self._camera_thread = threading.Thread(
|
||||
target=self._run_camera, daemon=True)
|
||||
self._camera_thread.start()
|
||||
|
||||
self._frame_consumer_thread = threading.Thread(target=self._frame_consumer, daemon=True)
|
||||
self._frame_consumer_thread = threading.Thread(
|
||||
target=self._frame_consumer, daemon=True)
|
||||
self._frame_consumer_thread.start()
|
||||
|
||||
if self._speaker_enabled:
|
||||
self._receive_audio_thread = threading.Thread(target=self._receive_audio, daemon=True)
|
||||
self._receive_audio_thread = threading.Thread(
|
||||
target=self._receive_audio, daemon=True)
|
||||
self._receive_audio_thread.start()
|
||||
|
||||
if self._vad_enabled:
|
||||
@@ -212,7 +263,6 @@ class BaseTransportService():
|
||||
new_confidence = model(
|
||||
torch.from_numpy(audio_float32), 16000).item()
|
||||
speaking = new_confidence > 0.5
|
||||
|
||||
if speaking:
|
||||
match self._vad_state:
|
||||
case VADState.QUIET:
|
||||
@@ -235,14 +285,16 @@ class BaseTransportService():
|
||||
self._vad_stopping_count += 1
|
||||
|
||||
if self._vad_state == VADState.STARTING and self._vad_starting_count >= self._vad_start_frames:
|
||||
print("##### VAD START")
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_queue.put(
|
||||
UserStartedSpeakingFrame()), self._loop
|
||||
)
|
||||
# self.interrupt()
|
||||
self.interrupt()
|
||||
self._vad_state = VADState.SPEAKING
|
||||
self._vad_starting_count = 0
|
||||
if self._vad_state == VADState.STOPPING and self._vad_stopping_count >= self._vad_stop_frames:
|
||||
print("##### VAD STOP")
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_queue.put(
|
||||
UserStoppedSpeakingFrame()), self._loop
|
||||
@@ -259,6 +311,7 @@ class BaseTransportService():
|
||||
break
|
||||
|
||||
def interrupt(self):
|
||||
print(f"!!!!! INTERRUPT")
|
||||
self._is_interrupted.set()
|
||||
|
||||
async def get_receive_frames(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
@@ -282,7 +335,6 @@ class BaseTransportService():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_queue.put(frame), self._loop
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_queue.put(EndStreamQueueFrame()), self._loop
|
||||
)
|
||||
@@ -309,13 +361,19 @@ class BaseTransportService():
|
||||
self._logger.info("🎬 Starting frame consumer thread")
|
||||
b = bytearray()
|
||||
smallest_write_size = 3200
|
||||
largest_write_size = 8000
|
||||
all_audio_frames = bytearray()
|
||||
while True:
|
||||
try:
|
||||
frames_or_frame: QueueFrame | list[QueueFrame] = (
|
||||
self._threadsafe_send_queue.get()
|
||||
)
|
||||
if isinstance(frames_or_frame, QueueFrame):
|
||||
if isinstance(frames_or_frame, AudioQueueFrame) and len(frames_or_frame.data) > largest_write_size:
|
||||
# subdivide large audio frames to enable interruption
|
||||
frames = []
|
||||
for i in range(0, len(frames_or_frame.data), largest_write_size):
|
||||
frames.append(AudioQueueFrame(frames_or_frame.data[i : i+largest_write_size]))
|
||||
elif isinstance(frames_or_frame, QueueFrame):
|
||||
frames: list[QueueFrame] = [frames_or_frame]
|
||||
elif isinstance(frames_or_frame, list):
|
||||
frames: list[QueueFrame] = frames_or_frame
|
||||
@@ -341,12 +399,15 @@ class BaseTransportService():
|
||||
len(b) % smallest_write_size
|
||||
)
|
||||
if truncated_length:
|
||||
self.write_frame_to_mic(bytes(b[:truncated_length]))
|
||||
self.write_frame_to_mic(
|
||||
bytes(b[:truncated_length]))
|
||||
b = b[truncated_length:]
|
||||
elif isinstance(frame, ImageQueueFrame):
|
||||
self._set_image(frame.image)
|
||||
elif isinstance(frame, SpriteQueueFrame):
|
||||
self._set_images(frame.images)
|
||||
elif isinstance(frame, ChatMessageQueueFrame):
|
||||
self._send_chat_message(frame)
|
||||
elif len(b):
|
||||
self.write_frame_to_mic(bytes(b))
|
||||
b = bytearray()
|
||||
@@ -355,7 +416,8 @@ class BaseTransportService():
|
||||
# can cause static in the audio stream.
|
||||
if len(b):
|
||||
truncated_length = len(b) - (len(b) % 160)
|
||||
self.write_frame_to_mic(bytes(b[:truncated_length]))
|
||||
self.write_frame_to_mic(
|
||||
bytes(b[:truncated_length]))
|
||||
b = bytearray()
|
||||
|
||||
if isinstance(frame, StartStreamQueueFrame):
|
||||
@@ -368,5 +430,6 @@ class BaseTransportService():
|
||||
|
||||
b = bytearray()
|
||||
except Exception as e:
|
||||
self._logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
|
||||
print(
|
||||
f"Exception in frame_consumer: {e}, {len(b)}")
|
||||
raise e
|
||||
|
||||
@@ -81,7 +81,8 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
for handler in self._event_handlers[event_name]:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
if self._loop:
|
||||
asyncio.run_coroutine_threadsafe(handler(*args, **kwargs), self._loop)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
handler(*args, **kwargs), self._loop)
|
||||
else:
|
||||
raise Exception(
|
||||
"No event loop to run coroutine. In order to use async event handlers, you must run the DailyTransportService in an asyncio event loop.")
|
||||
@@ -93,7 +94,8 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
|
||||
def add_event_handler(self, event_name: str, handler):
|
||||
if not event_name.startswith("on_"):
|
||||
raise Exception(f"Event handler {event_name} must start with 'on_'")
|
||||
raise Exception(
|
||||
f"Event handler {event_name} must start with 'on_'")
|
||||
|
||||
methods = inspect.getmembers(self, predicate=inspect.ismethod)
|
||||
if event_name not in [method[0] for method in methods]:
|
||||
@@ -106,7 +108,8 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
handler, self)]
|
||||
setattr(self, event_name, partial(self._patch_method, event_name))
|
||||
else:
|
||||
self._event_handlers[event_name].append(types.MethodType(handler, self))
|
||||
self._event_handlers[event_name].append(
|
||||
types.MethodType(handler, self))
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
def decorator(handler):
|
||||
@@ -150,6 +153,7 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.client.set_user_name(self._bot_name)
|
||||
|
||||
self.client.join(
|
||||
self._room_url,
|
||||
self._token,
|
||||
@@ -235,11 +239,19 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
self._other_participant_has_joined = True
|
||||
self.on_first_other_participant_joined()
|
||||
|
||||
"""
|
||||
def on_participant_left(self, participant, reason):
|
||||
if len(self.client.participants()) < self._min_others_count + 1:
|
||||
self._stop_threads.set()
|
||||
"""
|
||||
|
||||
def on_app_message(self, message, sender):
|
||||
print(f"app message: {message}")
|
||||
if self._loop:
|
||||
frame = TranscriptionQueueFrame(
|
||||
message["message"], message["name"], message["date"])
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_queue.put(frame), self._loop)
|
||||
pass
|
||||
|
||||
def on_transcription_message(self, message: dict):
|
||||
@@ -249,8 +261,10 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
participantId = message["participantId"]
|
||||
elif "session_id" in message:
|
||||
participantId = message["session_id"]
|
||||
frame = TranscriptionQueueFrame(message["text"], participantId, message["timestamp"])
|
||||
asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self._loop)
|
||||
frame = TranscriptionQueueFrame(
|
||||
message["text"], participantId, message["timestamp"])
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.receive_queue.put(frame), self._loop)
|
||||
|
||||
def on_transcription_stopped(self, stopped_by, stopped_by_error):
|
||||
pass
|
||||
@@ -260,3 +274,11 @@ class DailyTransportService(BaseTransportService, EventHandler):
|
||||
|
||||
def on_transcription_started(self, status):
|
||||
pass
|
||||
|
||||
def _send_chat_message(self, frame):
|
||||
self.client.send_app_message(
|
||||
{'message': frame.message, 'event': 'chat-msg', 'name': self._bot_name, 'date': time.time(), 'room': 'main-room'})
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
self.client.leave()
|
||||
|
||||
@@ -26,7 +26,8 @@ class ElevenLabsTTSService(TTSService):
|
||||
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
|
||||
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self._voice_id}/stream"
|
||||
payload = {"text": sentence, "model_id": "eleven_turbo_v2"}
|
||||
querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2}
|
||||
querystring = {"output_format": "pcm_16000",
|
||||
"optimize_streaming_latency": 2}
|
||||
headers = {
|
||||
"xi-api-key": self._api_key,
|
||||
"Content-Type": "application/json",
|
||||
|
||||
@@ -10,28 +10,39 @@ from dailyai.services.ai_services import LLMService, ImageGenService
|
||||
|
||||
|
||||
class OpenAILLMService(LLMService):
|
||||
def __init__(self, *, api_key, model="gpt-4"):
|
||||
def __init__(self, *, api_key, model="gpt-4", tools=None):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self._tools = tools
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
|
||||
async def get_response(self, messages, stream):
|
||||
return await self._client.chat.completions.create(
|
||||
stream=stream,
|
||||
messages=messages,
|
||||
model=self._model
|
||||
model=self._model,
|
||||
tools=self._tools
|
||||
)
|
||||
|
||||
async def run_llm_async(self, messages) -> AsyncGenerator[str, None]:
|
||||
async def run_llm_async(self, messages, tool_choice=None) -> AsyncGenerator[str, None]:
|
||||
messages_for_log = json.dumps(messages)
|
||||
self.logger.debug(f"Generating chat via openai: {messages_for_log}")
|
||||
|
||||
chunks = await self._client.chat.completions.create(model=self._model, stream=True, messages=messages)
|
||||
print("---")
|
||||
print(f"tools: {self._tools}")
|
||||
print("---")
|
||||
print(f"messages: {messages_for_log}")
|
||||
print("-----")
|
||||
if self._tools:
|
||||
tools = self._tools
|
||||
else:
|
||||
tools = None
|
||||
chunks = await self._client.chat.completions.create(model=self._model, stream=True, messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
async for chunk in chunks:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.content:
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
yield chunk.choices[0].delta.tool_calls[0]
|
||||
elif chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
async def run_llm(self, messages) -> str | None:
|
||||
|
||||
@@ -17,7 +17,8 @@ class CloudflareAIService(AIService):
|
||||
|
||||
# base endpoint, used by the others
|
||||
def run(self, model, input):
|
||||
response = requests.post(f"{self.api_base_url}{model}", headers=self.headers, json=input)
|
||||
response = requests.post(
|
||||
f"{self.api_base_url}{model}", headers=self.headers, json=input)
|
||||
return response.json()
|
||||
|
||||
# https://developers.cloudflare.com/workers-ai/models/llm/
|
||||
|
||||
@@ -1,68 +1,97 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
|
||||
from dailyai.services.ai_services import FrameLogger
|
||||
from dailyai.services.open_ai_services import OpenAILLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
from dailyai.queue_aggregators import LLMAssistantContextAggregator, LLMContextAggregator, LLMUserContextAggregator
|
||||
from support.runner import configure
|
||||
from examples.foundational.support.runner import configure
|
||||
from dailyai.queue_frame import LLMMessagesQueueFrame, TranscriptionQueueFrame, QueueFrame, TextQueueFrame
|
||||
from dailyai.services.ai_services import FrameLogger, AIService
|
||||
|
||||
class TranscriptFilter(AIService):
|
||||
def __init__(self, bot_participant_id=None):
|
||||
super().__init__()
|
||||
self.bot_participant_id = bot_participant_id
|
||||
print(f"Filtering transcripts from : {self.bot_participant_id}")
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if isinstance(frame, TranscriptionQueueFrame):
|
||||
if frame.participantId != self.bot_participant_id:
|
||||
yield frame
|
||||
|
||||
async def main(room_url: str, token):
|
||||
transport = DailyTransportService(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
duration_minutes=5,
|
||||
start_transcription=True,
|
||||
mic_enabled=True,
|
||||
mic_sample_rate=16000,
|
||||
camera_enabled=False,
|
||||
vad_enabled=True
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
global transport
|
||||
global llm
|
||||
global tts
|
||||
|
||||
llm = AzureLLMService(
|
||||
api_key=os.getenv("AZURE_CHATGPT_API_KEY"),
|
||||
endpoint=os.getenv("AZURE_CHATGPT_ENDPOINT"),
|
||||
model=os.getenv("AZURE_CHATGPT_MODEL"))
|
||||
tts = AzureTTSService(
|
||||
api_key=os.getenv("AZURE_SPEECH_API_KEY"),
|
||||
region=os.getenv("AZURE_SPEECH_REGION"))
|
||||
fl = FrameLogger("Inner")
|
||||
fl2 = FrameLogger("Outer")
|
||||
@transport.event_handler("on_first_other_participant_joined")
|
||||
async def on_first_other_participant_joined(transport):
|
||||
await tts.say("Hi, I'm listening!", transport.send_queue)
|
||||
transport = DailyTransportService(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
5,
|
||||
mic_enabled=True,
|
||||
mic_sample_rate=16000,
|
||||
camera_enabled=False
|
||||
)
|
||||
|
||||
# llm = AzureLLMService(api_key=os.getenv("AZURE_CHATGPT_API_KEY"), endpoint=os.getenv("AZURE_CHATGPT_ENDPOINT"), model=os.getenv("AZURE_CHATGPT_MODEL"))
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_CHATGPT_API_KEY"))
|
||||
# tts = AzureTTSService(api_key=os.getenv("AZURE_SPEECH_API_KEY"), region=os.getenv("AZURE_SPEECH_REGION"))
|
||||
tts = ElevenLabsTTSService(aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), voice_id="EXAVITQu4vr4xnSDxMaL")
|
||||
|
||||
async def handle_transcriptions():
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
{"role": "system", "content": """You are Valerie, an agent for a company called Valorant Health. Your job is to help users get access to health care. You're talking to Chad Bailey, a 40 year old male who needs to see a doctor.
|
||||
|
||||
You need to do three things, in this order:
|
||||
|
||||
1. Confirm the user's identity.
|
||||
2. Find out what kinds of doctors the user needs to see.
|
||||
3. Get the name of their insurance company.
|
||||
|
||||
Start by introducing yourself and asking the user to verify their identity by providing their date of birth. Once their identity is confirmed, move on to step 2, then to step 3.
|
||||
|
||||
Once you have collected all of that information, respond with a JSON object containing the answers."""}
|
||||
]
|
||||
tma_in = LLMUserContextAggregator(messages, transport._my_participant_id)
|
||||
tma_out = LLMAssistantContextAggregator(messages, transport._my_participant_id)
|
||||
await tts.run_to_queue(
|
||||
transport.send_queue,
|
||||
tma_out.run(
|
||||
fl2.run(
|
||||
llm.run(
|
||||
tma_in.run(
|
||||
fl.run(
|
||||
transport.get_receive_frames()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# checklist = ChecklistProcessor(messages, llm)
|
||||
|
||||
async def handle_transcriptions():
|
||||
tf = TranscriptFilter(transport._my_participant_id)
|
||||
await tts.run_to_queue(
|
||||
transport.send_queue,
|
||||
tma_out.run(
|
||||
llm.run(
|
||||
tma_in.run(
|
||||
tf.run(
|
||||
transport.get_receive_frames()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
transport.transcription_settings["extra"]["endpointing"] = True
|
||||
transport.transcription_settings["extra"]["punctuate"] = True
|
||||
await asyncio.gather(transport.run(), handle_transcriptions())
|
||||
|
||||
@transport.event_handler("on_first_other_participant_joined")
|
||||
async def on_first_other_participant_joined(transport):
|
||||
fl = FrameLogger("first other participant")
|
||||
await tts.run_to_queue(
|
||||
transport.send_queue,
|
||||
fl.run(
|
||||
tma_out.run(
|
||||
llm.run([LLMMessagesQueueFrame(messages)]),
|
||||
)
|
||||
)
|
||||
)
|
||||
transport.transcription_settings["extra"]["endpointing"] = True
|
||||
transport.transcription_settings["extra"]["punctuate"] = True
|
||||
await asyncio.gather(transport.run(), handle_transcriptions())
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
120
src/examples/foundational/06a-multi-step.py
Normal file
120
src/examples/foundational/06a-multi-step.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
|
||||
from dailyai.services.open_ai_services import OpenAILLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
from dailyai.queue_aggregators import LLMAssistantContextAggregator, LLMContextAggregator, LLMUserContextAggregator
|
||||
from examples.foundational.support.runner import configure
|
||||
from dailyai.queue_frame import LLMMessagesQueueFrame, TranscriptionQueueFrame, QueueFrame, TextQueueFrame
|
||||
from dailyai.services.ai_services import FrameLogger, AIService
|
||||
|
||||
class TranscriptFilter(AIService):
    """Let through only transcription frames that were not produced by the bot.

    Any frame that is not a ``TranscriptionQueueFrame`` is dropped, as is any
    transcription whose ``participantId`` equals ``bot_participant_id``.
    """

    def __init__(self, bot_participant_id=None):
        """Remember which participant id should be filtered out.

        Args:
            bot_participant_id: Participant id whose transcriptions are
                discarded. Defaults to ``None``.
        """
        super().__init__()
        self.bot_participant_id = bot_participant_id
        print(f"Filtering transcripts from : {self.bot_participant_id}")

    async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
        """Yield ``frame`` only when it is a transcription from someone else."""
        # Guard clauses: anything that is not a foreign transcription ends
        # the generator without yielding.
        if not isinstance(frame, TranscriptionQueueFrame):
            return
        if frame.participantId == self.bot_participant_id:
            return
        yield frame
|
||||
|
||||
class ChecklistProcessor(AIService):
    """Drive the LLM through a fixed, ordered list of scripted steps.

    Each step prompt tells the LLM to answer with the bare marker ``ABC`` once
    the step is complete. When a text frame carrying exactly that marker
    arrives, the next step's prompt is appended to the message list and the
    LLM is run again on the updated messages. Every other frame is passed
    through unchanged.
    """

    def __init__(self, messages, llm, *args, **kwargs):
        """Store the conversation state and seed it with the first step.

        Args:
            messages: The caller's message list. It is mutated in place: the
                identity prompt plus the first step is appended here, and
                later steps are appended by ``process_frame``.
            llm: Service whose ``process_frame`` is called to generate the
                response for each newly started step.
            *args: Forwarded to the base class constructor.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(*args, **kwargs)
        self._current_step = 0
        self._messages = messages
        self._llm = llm
        self._id = "You are Valerie, an agent for a company called Valorant Health. Your job is to help users get access to health care. You're talking to Chad Bailey, a 40 year old male who needs to see a doctor."
        self._steps = [
            "Start by introducing yourself. Then, ask the user to confirm their identity by telling you their birthday. After the user has confirmed their identity, respond only with ABC.",
            "Now that the user has confirmed their identity, ask them to describe what kind of doctor they need to see. When the user has responded with at least one kind of doctor, respond only with ABC.",
            "Next, you need to ask the user what kind of health insurance they have. Once the user has told you what insurance company they use, respond only with ABC.",
            "Tell the user goodbye.",
            "",
        ]
        messages.append({"role": "system", "content": f"{self._id} {self._steps[0]}"})

    async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
        """Advance the checklist on an ``ABC`` marker, else pass the frame on.

        Args:
            frame: The incoming frame.

        Yields:
            For an ``ABC`` text frame: an ``LLMMessagesQueueFrame`` with the
            updated messages, followed by whatever the injected LLM yields for
            those messages. Nothing is yielded if no step remains. For any
            other frame: the frame itself.
        """
        if not isinstance(frame, TextQueueFrame):
            yield frame
            return

        print(f"got a text frame: {frame.text}")
        if frame.text != "ABC":
            yield frame
            return

        next_step = self._current_step + 1
        if next_step >= len(self._steps):
            # Every step has already been issued. Swallow the marker instead
            # of indexing past the end of the step list (IndexError).
            return

        self._current_step = next_step
        self._messages.append({"role": "system", "content": self._steps[self._current_step]})
        yield LLMMessagesQueueFrame(self._messages)
        print(f"past llmmessagesqueueframe yield")
        # Use the LLM handed to the constructor, not a module-level global
        # named `llm`, so the processor works wherever it is instantiated.
        async for response_frame in self._llm.process_frame(LLMMessagesQueueFrame(self._messages)):
            yield response_frame
|
||||
|
||||
async def main(room_url: str, token):
    """Join the room and run the checklist-driven voice bot until it finishes.

    Builds the transport, LLM and TTS services, wires them into a frame
    pipeline, registers a handler that runs the LLM once the first other
    participant joins, and then runs the transport and the transcription
    pipeline concurrently.

    Args:
        room_url: Passed as the first argument to ``DailyTransportService``.
        token: Passed as the second argument to ``DailyTransportService``.
    """
    # These services are published as module-level globals.
    global transport
    global llm
    global tts

    async with aiohttp.ClientSession() as session:
        # The 4th positional argument (5) is presumably the session duration
        # in minutes - TODO confirm against DailyTransportService.
        transport = DailyTransportService(room_url, token, "Respond bot", 5)
        transport.mic_enabled = True
        transport.mic_sample_rate = 16000
        transport.camera_enabled = False

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_CHATGPT_API_KEY"))
        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id="EXAVITQu4vr4xnSDxMaL",
        )

        # One message list is shared by both aggregators and the checklist.
        messages = []
        my_id = transport._my_participant_id
        tma_in = LLMUserContextAggregator(messages, my_id)
        tma_out = LLMAssistantContextAggregator(messages, my_id)
        checklist = ChecklistProcessor(messages, llm)

        async def handle_transcriptions():
            """Feed received frames through filter -> LLM -> checklist -> TTS."""
            tf = TranscriptFilter(transport._my_participant_id)
            send_queue = transport.send_queue
            received = transport.get_receive_frames()
            foreign_transcripts = tf.run(received)
            user_context = tma_in.run(foreign_transcripts)
            llm_output = llm.run(user_context)
            assistant_context = tma_out.run(llm_output)
            checked = checklist.run(assistant_context)
            await tts.run_to_queue(send_queue, checked)

        @transport.event_handler("on_first_other_participant_joined")
        async def on_first_other_participant_joined(transport):
            """Run the LLM on the current messages and speak the result."""
            fl = FrameLogger("first other participant")
            send_queue = transport.send_queue
            llm_output = llm.run([LLMMessagesQueueFrame(messages)])
            assistant_context = tma_out.run(llm_output)
            logged = fl.run(assistant_context)
            await tts.run_to_queue(send_queue, logged)

        transport.transcription_settings["extra"]["punctuate"] = True
        await asyncio.gather(transport.run(), handle_transcriptions())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # configure() supplies the two values that main() takes as its room URL
    # and token; run the bot to completion on a fresh event loop.
    url, token = configure()
    asyncio.run(main(url, token))
|
||||
483
src/examples/foundational/06b-patient-intake.py
Normal file
483
src/examples/foundational/06b-patient-intake.py
Normal file
@@ -0,0 +1,483 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import os
|
||||
import wave
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
|
||||
from dailyai.services.open_ai_services import OpenAILLMService
|
||||
from dailyai.services.deepgram_ai_services import DeepgramTTSService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
from dailyai.queue_aggregators import LLMAssistantContextAggregator, LLMContextAggregator, LLMUserContextAggregator
|
||||
from support.runner import configure
|
||||
from dailyai.queue_frame import LLMMessagesQueueFrame, TranscriptionQueueFrame, QueueFrame, TextQueueFrame, LLMFunctionCallFrame, LLMResponseEndQueueFrame, StartStreamQueueFrame, AudioQueueFrame
|
||||
from dailyai.services.ai_services import FrameLogger, AIService
|
||||
from dailyai.conversation_wrappers import InterruptibleConversationWrapper
|
||||
|
||||
import logging
|
||||
logging.basicConfig(level=logging.ERROR)

# Raw audio frames for each sound effect, keyed by the asset's file name
# including its extension (e.g. sounds["clack-short-quiet.wav"]).
sounds = {}
sound_files = [
    'clack-short.wav',
    'clack.wav',
    'clack-short-quiet.wav'
]

script_dir = os.path.dirname(__file__)

for file in sound_files:
    # The WAV assets live in the "assets" directory next to this script.
    full_path = os.path.join(script_dir, "assets", file)
    # Load the WAV file's audio frames into memory as bytes; the context
    # manager closes the file once the frames have been read.
    with wave.open(full_path) as audio_file:
        sounds[file] = audio_file.readframes(-1)
|
||||
|
||||
def _tool_string_field(description):
    """Return the schema of a string-typed property with the given description."""
    return {"type": "string", "description": description}


def _tool_object_list(item_properties):
    """Return the schema of an array whose items are objects with the given properties."""
    return {
        "type": "array",
        "items": {
            "type": "object",
            "properties": item_properties,
        },
    }


def _tool_definition(name, description, properties):
    """Return a function-tool schema whose parameters object has the given properties."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
            },
        },
    }


# Full catalogue of function-tool schemas used by the patient-intake flow.
tools = [
    _tool_definition(
        "verify_birthday",
        "Use this function to verify the user has provided their correct birthday.",
        {
            "birthday": _tool_string_field(
                "The user's birthdate, including the year. The user can provide it in any format, but convert it to YYYY-MM-DD format to call this function."
            ),
        },
    ),
    _tool_definition(
        "list_prescriptions",
        "Once the user has provided a list of their prescription medications, call this function.",
        {
            "prescriptions": _tool_object_list({
                "name": _tool_string_field("The medication's name"),
                "dosage": _tool_string_field("The prescription's dosage"),
            }),
        },
    ),
    _tool_definition(
        "list_allergies",
        "Once the user has provided a list of their allergies, call this function.",
        {
            "allergies": _tool_object_list({
                "name": _tool_string_field("What the user is allergic to"),
            }),
        },
    ),
    _tool_definition(
        "list_conditions",
        "Once the user has provided a list of their medical conditions, call this function.",
        {
            "conditions": _tool_object_list({
                "name": _tool_string_field("The user's medical condition"),
            }),
        },
    ),
    _tool_definition(
        "list_visit_reasons",
        "Once the user has provided a list of the reasons they are visiting a doctor today, call this function.",
        {
            "visit_reasons": _tool_object_list({
                "name": _tool_string_field("The user's reason for visiting the doctor"),
            }),
        },
    ),
]
|
||||
|
||||
def _step_string_field(description):
    """Return the schema of a string-typed property with the given description."""
    return {"type": "string", "description": description}


def _step_object_list(item_properties):
    """Return the schema of an array whose items are objects with the given properties."""
    return {
        "type": "array",
        "items": {
            "type": "object",
            "properties": item_properties,
        },
    }


def _step_function_tool(name, description, properties):
    """Return a function-tool schema whose parameters object has the given properties."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
            },
        },
    }


# Ordered intake script. Each step carries the system prompt to add to the
# conversation, the tools offered to the LLM during that step, and a
# "run_async" flag read by ChecklistProcessor; only the first step defines a
# "failed" prompt.
steps = [
    {
        "prompt": "Start by introducing yourself. Then, ask the user to confirm their identity by telling you their birthday, including the year. When they answer with their birthday, call the verify_birthday function.",
        "run_async": False,
        "failed": "The user provided an incorrect birthday. Ask them for their birthday again. When they answer, call the verify_birthday function.",
        "tools": [
            _step_function_tool(
                "verify_birthday",
                "Use this function to verify the user has provided their correct birthday.",
                {
                    "birthday": _step_string_field(
                        "The user's birthdate, including the year. The user can provide it in any format, but convert it to YYYY-MM-DD format to call this function."
                    ),
                },
            ),
        ],
    },
    {
        "prompt": "Next, thank the user for confirming their identity, then ask the user to list their current prescriptions. Each prescription needs to have a medication name and a dosage. Do not call the list_prescriptions function with any unknown dosages.",
        "run_async": True,
        "tools": [
            _step_function_tool(
                "list_prescriptions",
                "Once the user has provided a list of their prescription medications, call this function.",
                {
                    "prescriptions": _step_object_list({
                        "medication": _step_string_field("The medication's name"),
                        "dosage": _step_string_field("The prescription's dosage"),
                    }),
                },
            ),
        ],
    },
    {
        "prompt": "Next, ask the user if they have any allergies. Once they have listed their allergies or confirmed they don't have any, call the list_allergies function.",
        "run_async": True,
        "tools": [
            _step_function_tool(
                "list_allergies",
                "Once the user has provided a list of their allergies, call this function.",
                {
                    "allergies": _step_object_list({
                        "name": _step_string_field("What the user is allergic to"),
                    }),
                },
            ),
        ],
    },
    {
        "prompt": "Now ask the user if they have any medical conditions the doctor should know about. Once they've answered the question, call the list_conditions function.",
        "run_async": True,
        "tools": [
            _step_function_tool(
                "list_conditions",
                "Once the user has provided a list of their medical conditions, call this function.",
                {
                    "conditions": _step_object_list({
                        "name": _step_string_field("The user's medical condition"),
                    }),
                },
            ),
        ],
    },
    {
        "prompt": "Finally, ask the user the reason for their doctor visit today. Once they answer, double-check to make sure they don't have any other health concerns. After that, call the list_visit_reasons function.",
        "run_async": True,
        "tools": [
            _step_function_tool(
                "list_visit_reasons",
                "Once the user has provided a list of the reasons they are visiting a doctor today, call this function.",
                {
                    "visit_reasons": _step_object_list({
                        "name": _step_string_field("The user's reason for visiting the doctor"),
                    }),
                },
            ),
        ],
    },
    {"prompt": "Now, thank the user and end the conversation.", "run_async": True, "tools": []},
    {"prompt": "", "run_async": True, "tools": []},
]

# Index into `steps` of the step the conversation is currently on.
current_step = 0
|
||||
|
||||
class TranscriptFilter(AIService):
    """Pass along only transcription frames that did not come from the bot.

    Frames of any other type, and transcriptions whose participant id equals
    ``bot_participant_id``, are dropped.
    """

    def __init__(self, bot_participant_id=None):
        super().__init__()
        # Participant id whose transcriptions are suppressed.
        self.bot_participant_id = bot_participant_id
        print(f"Filtering transcripts from : {self.bot_participant_id}")

    async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
        """Yield ``frame`` only if it is a transcription from another participant."""
        if not isinstance(frame, TranscriptionQueueFrame):
            return
        if frame.participantId == self.bot_participant_id:
            return
        yield frame
|
||||
|
||||
|
||||
class ChecklistProcessor(AIService):
    """Drive the patient-intake checklist from the LLM's function calls.

    The processor watches the LLM output stream. When the model calls one of
    the step functions it runs the matching handler, advances the module-level
    ``current_step`` through ``steps``, appends the next system prompt to the
    shared message list and re-runs the LLM so the conversation continues.
    Frames that are not function-call or response-end frames are passed
    through unchanged.

    Args:
        messages: Shared, mutable list of chat messages. The initial system
            prompt is appended to it on construction.
        llm: LLM service used to generate the follow-up response after a
            step change.
        tools: Shared, mutable list of tool schemas. It is rewritten in place
            on every frame to hold the current step's tools.
    """

    # The only methods the LLM is allowed to trigger by name. Function names
    # come from model output, so they are checked against this list instead
    # of being resolved to an arbitrary attribute.
    _HANDLER_NAMES = (
        "verify_birthday",
        "list_prescriptions",
        "list_allergies",
        "list_conditions",
        "list_visit_reasons",
    )

    # Birthday (YYYY-MM-DD) accepted by verify_birthday.
    _EXPECTED_BIRTHDAY = "1983-08-19"

    def __init__(self, messages, llm, tools, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._messages = messages
        self._llm = llm
        self._tools = tools
        # Name and JSON argument text of the function call currently being
        # streamed by the LLM; the arguments arrive in fragments.
        self._function_name = ""
        self._arguments = ""
        self._id = "You are Jessica, an agent for a company called Tri-County Advanced Optimum Health Solution Specialists. Your job is to collect important information from the user before they visit a doctor. You're talking to Chad Bailey. You should address the user by their first name and be polite and professional. You're not a medical professional, so you shouldn't provide any advice. Keep your responses short. Your job is to collect information to give to a doctor. Don't make assumptions about what values to plug into functions. Ask for clarification if a user response is ambiguous."
        self._acks = ["One sec.", "Let me confirm that.", "Thanks.", "OK."]

        messages.append(
            {"role": "system", "content": f"{self._id} {steps[0]['prompt']}"})

    def verify_birthday(self, args):
        """Return True if the supplied birthday matches the expected one.

        A missing "birthday" key counts as a mismatch rather than an error,
        because the tool schema does not mark the argument as required.
        """
        return args.get('birthday') == self._EXPECTED_BIRTHDAY

    def list_prescriptions(self, args):
        """Print the prescriptions reported by the user."""
        print(f"Prescriptions: {args.get('prescriptions', [])}")

    def list_allergies(self, args):
        """Print the allergies reported by the user."""
        print(f"Allergies: {args.get('allergies', [])}")

    def list_conditions(self, args):
        """Print the medical conditions reported by the user."""
        print(f"Medical Conditions: {args.get('conditions', [])}")

    def list_visit_reasons(self, args):
        """Print the reasons for the visit reported by the user."""
        print(f"Visit Reasons: {args.get('visit_reasons', [])}")

    async def _prompt_llm(self, prompt):
        """Append ``prompt`` as a system message and stream the LLM's reply.

        Yields the updated messages frame first, then every frame the LLM
        produces for it. Tool use is disabled for this follow-up response.
        """
        self._messages.append({"role": "system", "content": prompt})
        yield LLMMessagesQueueFrame(self._messages)
        async for response_frame in self._llm.process_frame(
                LLMMessagesQueueFrame(self._messages), tool_choice="none"):
            yield response_frame

    async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
        """Handle one frame from the LLM output stream.

        * Function-call frame carrying a name: remember the name. For
          ``run_async`` steps, advance to the next step immediately and let
          the LLM start talking about it; otherwise emit a short sound while
          the rest of the call streams in.
        * Function-call frame carrying arguments: accumulate the fragment.
        * Response-end frame: if a complete call was collected, run the
          handler. For non-async steps a truthy result advances the step and
          a falsy one replays the step's "failed" prompt.
        * Any other frame is yielded unchanged.
        """
        global current_step
        this_step = steps[current_step]
        # TODO-CB: forcing a global here :/
        self._tools.clear()
        self._tools.extend(this_step['tools'])

        if isinstance(frame, LLMFunctionCallFrame) and frame.function_name:
            print(f"FUNCTION CALL: {frame}")
            self._function_name = frame.function_name
            if this_step['run_async']:
                # Get the LLM talking about the next step before getting the
                # rest of the function call completion.
                current_step += 1
                async for out_frame in self._prompt_llm(steps[current_step]['prompt']):
                    yield out_frame
            else:
                # Insert a quick response while we run the function.
                yield AudioQueueFrame(sounds["clack-short-quiet.wav"])
        elif isinstance(frame, LLMFunctionCallFrame) and frame.arguments:
            self._arguments += frame.arguments
        elif isinstance(frame, LLMResponseEndQueueFrame):
            print(
                f"%%% got a response end. function_name is {self._function_name}, arguments is {self._arguments}")
            print(f"%%%% messages is {self._messages}")

            if self._function_name and self._arguments:
                # Take ownership of the pending call and clear the buffers
                # first, so a failure while parsing or running it cannot leak
                # stale state into the next response.
                function_name, arguments = self._function_name, self._arguments
                self._function_name = ""
                self._arguments = ""

                if function_name not in self._HANDLER_NAMES:
                    print(f"Ignoring call to unknown function: {function_name}")
                    return

                fn = getattr(self, function_name)
                print(f"fn is: {fn}")
                result = fn(json.loads(arguments))

                if not this_step['run_async']:
                    if result:
                        current_step += 1
                        follow_up = steps[current_step]['prompt']
                    else:
                        follow_up = this_step['failed']
                    async for out_frame in self._prompt_llm(follow_up):
                        yield out_frame
                    print(f"VERIFY RESULT: {result}")
        else:
            yield frame
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Run the patient-intake bot in the given room until the transport stops.

    Builds the transport, LLM and TTS services, wires them into the checklist
    pipeline, greets the first participant who joins, and then runs the
    transport together with its conversation loop.

    Args:
        room_url: URL of the room to join.
        token: Token passed to the transport together with ``room_url``.
    """
    async with aiohttp.ClientSession() as session:
        # These are published as module globals; ChecklistProcessor has
        # historically looked up `llm` at module scope.
        global transport
        global llm
        global tts

        transport = DailyTransportService(
            room_url,
            token,
            "Respond bot",
            5,
            mic_enabled=True,
            mic_sample_rate=16000,
            camera_enabled=False,
            start_transcription=True,
            vad_enabled=True
        )

        # Both lists are shared by reference: the aggregators and the
        # checklist mutate `messages`, and the checklist rewrites `tools` in
        # place so the LLM service sees the current step's tools.
        messages = []
        tools = []

        # llm = AzureLLMService(api_key=os.getenv("AZURE_CHATGPT_API_KEY"), endpoint=os.getenv("AZURE_CHATGPT_ENDPOINT"), model=os.getenv("AZURE_CHATGPT_MODEL"))
        llm = OpenAILLMService(
            api_key=os.getenv("OPENAI_CHATGPT_API_KEY"),
            model="gpt-4-turbo-preview",
            tools=tools)
        # tts = AzureTTSService(api_key=os.getenv(
        #     "AZURE_SPEECH_API_KEY"), region=os.getenv("AZURE_SPEECH_REGION"))
        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id="XrExE9yKIg1WjnnlVkGX")  # matilda
        # tts = DeepgramTTSService(aiohttp_session=session, api_key=os.getenv("DEEPGRAM_API_KEY"), voice=os.getenv("DEEPGRAM_VOICE"))

        tma_in = LLMUserContextAggregator(
            messages, transport._my_participant_id)
        tma_out = LLMAssistantContextAggregator(
            messages, transport._my_participant_id)
        checklist = ChecklistProcessor(messages, llm, tools)

        async def run_response(user_speech, tma_in, tma_out):
            """Answer one user utterance: user text -> LLM -> checklist -> TTS."""
            await tts.run_to_queue(
                transport.send_queue,
                checklist.run(
                    tma_out.run(
                        llm.run(
                            tma_in.run(
                                [StartStreamQueueFrame(), TextQueueFrame(user_speech)]
                            )
                        )
                    )
                )
            )

        @transport.event_handler("on_first_other_participant_joined")
        async def on_first_other_participant_joined(transport):
            """Speak the LLM's opening response when the first participant joins."""
            fl = FrameLogger("first other participant")
            await tts.run_to_queue(
                transport.send_queue,
                fl.run(
                    tma_out.run(
                        llm.run([LLMMessagesQueueFrame(messages)]),
                    )
                )
            )

        transport.transcription_settings["extra"]["endpointing"] = True
        transport.transcription_settings["extra"]["punctuate"] = True
        try:
            await asyncio.gather(transport.run(), transport.run_conversation(run_response))
        except (asyncio.CancelledError, KeyboardInterrupt):
            print('whoops')
            transport.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Resolve the room URL and token via the shared example runner, then
    # drive the bot's coroutine to completion.
    url, token = configure()
    asyncio.run(main(url, token))
|
||||
@@ -6,6 +6,7 @@ from dailyai.conversation_wrappers import InterruptibleConversationWrapper
|
||||
from dailyai.queue_frame import StartStreamQueueFrame, TextQueueFrame
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
|
||||
from dailyai.services.open_ai_services import OpenAILLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
|
||||
from examples.foundational.support.runner import configure
|
||||
|
||||
BIN
src/examples/foundational/assets/clack-short-quiet.wav
Normal file
BIN
src/examples/foundational/assets/clack-short-quiet.wav
Normal file
Binary file not shown.
BIN
src/examples/foundational/assets/clack-short.wav
Normal file
BIN
src/examples/foundational/assets/clack-short.wav
Normal file
Binary file not shown.
BIN
src/examples/foundational/assets/clack.wav
Normal file
BIN
src/examples/foundational/assets/clack.wav
Normal file
Binary file not shown.
Reference in New Issue
Block a user