Compare commits
8 Commits
hush/hidde
...
greedy-plu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de4f3b6c44 | ||
|
|
4ed6648f99 | ||
|
|
dc5efe3028 | ||
|
|
07041cccce | ||
|
|
ce45a5f8bc | ||
|
|
4f1e9e2d50 | ||
|
|
b2c92c3225 | ||
|
|
2b324e4b01 |
@@ -2,6 +2,7 @@ autopep8~=2.1.0
|
||||
build~=1.2.1
|
||||
grpcio-tools~=1.62.2
|
||||
pip-tools~=7.4.1
|
||||
pyright~=1.1.367
|
||||
pytest~=8.2.0
|
||||
setuptools~=69.5.1
|
||||
setuptools_scm~=8.1.0
|
||||
|
||||
114
examples/foundational/tmp-khk-sonnet-ttft.py
Normal file
114
examples/foundational/tmp-khk-sonnet-ttft.py
Normal file
@@ -0,0 +1,114 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Join a Daily room and run the ice-cream-flavor bot pipeline.

    Builds a Daily transport (audio out, transcription and Silero VAD
    enabled), a Cartesia TTS service and an Anthropic LLM service, chains
    them into a pipeline and runs it until the runner returns.

    Args:
        room_url: URL of the Daily room to join.
        token: Token handed to the Daily transport.
    """
    async with aiohttp.ClientSession() as session:
        daily_params = DailyParams(
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        # The first command-line argument, when present, selects the voice.
        voice_name = sys.argv[1] if len(sys.argv) > 1 else "British Lady"
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_name=voice_name,
        )

        llm = AnthropicLLMService(
            api_key=os.getenv("ANTHROPIC_API_KEY"),
            model="claude-3-5-sonnet-20240620",
            temperature=1.0,
        )

        # todo: think more about how to handle system prompts in a more general
        # way. OpenAI, Google, and Anthropic all have slightly different
        # approaches to providing a system prompt.
        system_prompt = (
            "You are participating in a friendly competition to invent creative "
            "new ice cream flavors. Say the craziest flavor you can think of "
            "then wait for your opponent to come up with a different crazy flavor. "
            "then respond with another flavor idea. Repeat forever. Say only the "
            "ice cream flavors and nothing else. End each ice cream flavor statement "
            "with an exclamation mark! Go ..."
        )
        messages = [{"role": "system", "content": system_prompt}]

        # Both aggregators are built around the same message list.
        tma_in = LLMUserResponseAggregator(messages)
        tma_out = LLMAssistantResponseAggregator(messages)

        processors = [
            transport.input(),   # Transport user input
            tma_in,              # User responses
            llm,                 # LLM
            tts,                 # TTS
            transport.output(),  # Transport bot output
            tma_out,             # Assistant spoken responses
        ]
        pipeline = Pipeline(processors)

        task_params = PipelineParams(allow_interruptions=True, enable_metrics=True)
        task = PipelineTask(pipeline, task_params)

        # When a participant joins, start transcription for that participant so
        # the bot can "hear" and respond to them.
        @transport.event_handler("on_participant_joined")
        async def on_participant_joined(transport, participant):
            transport.capture_participant_transcription(participant["id"])

        # When the first participant joins, kick off the conversation by
        # sending the message list to the LLM.
        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()
        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # configure() yields the room URL and token that main() expects.
    url, token = configure()
    asyncio.run(main(url, token))
|
||||
|
||||
# '{"action":"app-message","data":{"metrics":{"ttfb":[{"name":"AnthropicLLMService#0","time":0.5975627899169922}]},"type":"pipecat-metrics"},"fromId":"592d3489-90ba-401d-a760-c1a863d64a4a","callFrameId":"17189290998160.035120590426112264"}'
|
||||
# [Durian and Limburger Cheese Charcoal Activated Tar Twist!]
|
||||
# [Fermented Fish Sauce and Ghost Pepper Bubblegum Cotton Candy Nightmare!]
|
||||
# [Spoiled Yogurt and Ghost Pepper Gummy Bear Blizzard!]
|
||||
# [Matcha Green Tea and Sour Gummy Worm Fusion!]
|
||||
324
examples/foundational/tmp-khk.py
Normal file
324
examples/foundational/tmp-khk.py
Normal file
@@ -0,0 +1,324 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TextFrame,
|
||||
LLMMessagesFrame,
|
||||
TranscriptionFrame,
|
||||
InterimTranscriptionFrame,
|
||||
AudioRawFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSStoppedFrame
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.logger import FrameLogger
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.services.deepgram import DeepgramTTSService
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAILLMContext, OpenAILLMContextFrame
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport, DailyTransportMessageFrame
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams, VADState
|
||||
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class GreedyLLMAggregator(FrameProcessor):
    """Push an updated LLM context downstream on every final transcription.

    Interim transcriptions are dropped. Each ``TranscriptionFrame`` has its
    text merged into the trailing "user" message of the context (a new one
    is created if the last message is not from the user), and an
    ``OpenAILLMContextFrame`` carrying the context is pushed. Every other
    frame is forwarded unchanged in its original direction.
    """

    def __init__(self, context: OpenAILLMContext = None, **kwargs):
        """Create the aggregator.

        Args:
            context: Context to accumulate user messages into. A fresh
                ``OpenAILLMContext`` is created when omitted.
            **kwargs: Forwarded to ``FrameProcessor.__init__``.
        """
        super().__init__(**kwargs)
        # Compare against None explicitly so a caller-supplied context is
        # never replaced just because it happens to evaluate as falsy.
        self.context: OpenAILLMContext = (
            context if context is not None else OpenAILLMContext()
        )

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Handle one frame; see the class docstring for the routing rules.

        Errors are logged and swallowed so a single bad frame does not stop
        the processor (best-effort, as in the original implementation).
        """
        await super().process_frame(frame, direction)

        logger.debug(f"{frame}")

        try:
            if isinstance(frame, InterimTranscriptionFrame):
                return

            if isinstance(frame, TranscriptionFrame):
                messages = self.context.messages
                if messages and messages[-1]["role"] == "user":
                    # Extend the user turn that is already in progress.
                    user_message = messages[-1]
                else:
                    user_message = {"role": "user", "content": ""}
                    messages.append(user_message)

                # Separate fragments with a single space, but do not start a
                # brand-new message with a leading space.
                existing = user_message["content"]
                user_message["content"] = (
                    f"{existing} {frame.text}" if existing else frame.text
                )

                oai_context_frame = OpenAILLMContextFrame(context=self.context)
                logger.debug(f"pushing frame {oai_context_frame}")
                await self.push_frame(oai_context_frame)
                return

            await self.push_frame(frame, direction)
        except Exception as e:
            # Log with traceback instead of a bare debug line so failures
            # are visible; the frame is still dropped rather than re-raised.
            logger.exception(f"error: {e}")
|
||||
|
||||
|
||||
class ClearableDeepgramTTSService(DeepgramTTSService):
    """Deepgram TTS service that resets ``_current_sentence`` on interruption.

    NOTE(review): ``_current_sentence`` is a private attribute that is not
    defined in this file; presumably it is the base TTS service's pending
    sentence buffer -- confirm against the base class.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``DeepgramTTSService``.

        The original spelled this ``__init___`` (three trailing underscores)
        and called ``super().__init(...)``, so it was never invoked as the
        constructor and would have raised AttributeError if called.
        """
        super().__init__(**kwargs)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Run the normal TTS processing, then clear state on interruption."""
        await super().process_frame(frame, direction)

        if isinstance(frame, StartInterruptionFrame):
            # Reset after the base class has handled the frame, so its own
            # processing cannot overwrite the cleared value.
            self._current_sentence = ""
|
||||
|
||||
|
||||
@dataclass
class BufferedSentence:
    """One TTS sentence: its audio frames plus the text frame that follows.

    ``text_frame`` stays ``None`` until the sentence's text has been
    received; consumers in this file use that to tell a finished sentence
    from one that is still being buffered.
    """

    # Audio frames collected for this sentence, in arrival order.
    audio_frames: List[AudioRawFrame] = field(default_factory=list)
    # Text frame for the sentence; None while the sentence is incomplete.
    text_frame: TextFrame = None
|
||||
|
||||
|
||||
class VADGate(FrameProcessor):
    """Hold TTS output back and release it one sentence at a time.

    Audio frames are buffered per sentence. A background task pushes a
    completed sentence downstream only while the VAD analyzer reports
    ``VADState.QUIET``, then appends the sentence text to the shared
    conversation context (if one was given).

    Design notes:
      1. Nothing is sent while the VAD state is not QUIET.
      2. While QUIET, a sentence is sent, its playout time is estimated and
         the task sleeps that long before sending the next sentence.
      3. Each sent sentence is appended to the conversation context.
      4. A new LLM response or a user interruption discards everything that
         is buffered and cancels the pusher task.
    """

    def __init__(
            self,
            vad_analyzer: VADAnalyzer = None,
            context: OpenAILLMContext = None,
            **kwargs):
        """Create the gate.

        Args:
            vad_analyzer: Analyzer whose state gates the output. When None,
                the gate behaves as if the user were always quiet.
            context: Optional context that receives assistant messages as
                sentences are played out.
            **kwargs: Forwarded to ``FrameProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.vad_analyzer = vad_analyzer
        self.context = context

        self._audio_pusher_task = None
        self._expect_text_frame_next = False
        self._sentences: List[BufferedSentence] = []

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Buffer TTS output and react to interruptions.

        Errors are logged and swallowed (best-effort), so one bad frame does
        not stop the processor.
        """
        await super().process_frame(frame, direction)

        try:
            # A TTSService will emit a series of AudioRawFrame objects, then
            # a TTSStoppedFrame, then a TextFrame.
            if self._expect_text_frame_next:
                self._expect_text_frame_next = False
                if isinstance(frame, TextFrame):
                    if self._sentences:
                        self._sentences[-1].text_frame = frame
                    else:
                        # The buffer was cleared by an interruption between
                        # the audio and its text; the sentence is gone, so
                        # its text is discarded with it.
                        logger.debug(f"dropping text frame for discarded sentence: {frame}")
                    return
                # Not the text frame we expected. Fall through to the normal
                # handling below instead of forwarding blindly, so that an
                # interruption arriving here still clears the buffer.
                logger.debug(f"expected a text frame, but received {frame}")

            if isinstance(frame, AudioRawFrame):
                # If the buffer is empty or ends with a "finished" sentence,
                # start buffering a new sentence.
                if not self._sentences or self._sentences[-1].text_frame:
                    self._sentences.append(BufferedSentence())
                self._sentences[-1].audio_frames.append(frame)
                await self.maybe_start_audio_pusher_task()
                return

            if isinstance(frame, TTSStoppedFrame):
                self._expect_text_frame_next = True
                await self.push_frame(frame, direction)
                return

            # There are two ways we can be interrupted. During greedy
            # inference, a new LLM response can start. Or, during playout,
            # we can get a traditional user interruption frame.
            if isinstance(frame, (LLMFullResponseStartFrame, StartInterruptionFrame)):
                logger.debug(f"{frame} - Handle interruption in VADGate")
                self._sentences = []
                if self._audio_pusher_task:
                    self._audio_pusher_task.cancel()
                    self._audio_pusher_task = None
                await self.push_frame(frame, direction)
                return

            await self.push_frame(frame, direction)
        except Exception as e:
            logger.exception(f"error: {e}")

    async def maybe_start_audio_pusher_task(self):
        """Start the pusher task unless one is already running.

        A task that has already finished (for example after an exception in
        ``push_audio``) is replaced; otherwise buffered audio would never be
        played again.
        """
        task = self._audio_pusher_task
        if task is not None and not task.done():
            return
        self._audio_pusher_task = self.get_event_loop().create_task(self.push_audio())

    async def push_audio(self):
        """Background loop that releases completed sentences while quiet."""
        try:
            while True:
                # NOTE: reads the analyzer's private ``_vad_state``; there is
                # no public accessor visible in this file.
                vad_quiet = (
                    self.vad_analyzer is None
                    or self.vad_analyzer._vad_state == VADState.QUIET
                )

                # Poll until there is a completed sentence and the user is
                # quiet. Only completed sentence buffers are pushed.
                if (not self._sentences
                        or not vad_quiet
                        or not self._sentences[0].text_frame):
                    await asyncio.sleep(0.01)
                    continue

                s = self._sentences.pop(0)
                if not s.audio_frames:
                    continue

                sample_rate = s.audio_frames[0].sample_rate
                duration = 0
                logger.debug(f"Pushing {len(s.audio_frames)} audio frames")
                for frame in s.audio_frames:
                    await self.push_frame(frame)
                    # assume linear16 encoding (2 bytes per sample). todo: add
                    # some more metadata to AudioRawFrame, maybe
                    duration += (len(frame.audio) / 2 / frame.num_channels) / sample_rate

                # Wake up 20 ms before the estimated end of playout; never
                # pass a negative delay.
                await asyncio.sleep(max(0.0, duration - 20 / 1000))

                if self.context:
                    logger.debug(f"Appending assistant message to context: [{s.text_frame.text}]")
                    messages = self.context.messages
                    if messages and messages[-1]["role"] == "assistant":
                        messages[-1]["content"] += " " + s.text_frame.text
                    else:
                        messages.append(
                            {"role": "assistant", "content": s.text_frame.text}
                        )
                await self.push_frame(s.text_frame)
        except Exception as e:
            # Cancellation is not an Exception subclass, so it propagates
            # normally; anything else is logged with its traceback. The task
            # then ends and is restarted by maybe_start_audio_pusher_task().
            logger.exception(f"Exception {e}")
|
||||
|
||||
|
||||
async def main(room_url: str, token):
    """Join a Daily room and run the greedy-inference bot pipeline.

    The pipeline is: transport input -> GreedyLLMAggregator -> OpenAI LLM ->
    Deepgram TTS -> VADGate -> transport output. The aggregator and the gate
    share one ``OpenAILLMContext``.

    Args:
        room_url: URL of the Daily room to join.
        token: Token handed to the Daily transport.
    """
    # Local import: only the app-message handler below needs it.
    import inspect

    async with aiohttp.ClientSession() as session:
        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.5))
            )
        )

        tts = ClearableDeepgramTTSService(
            aiohttp_session=session,
            api_key=os.getenv("DEEPGRAM_API_KEY"),
            voice="aura-asteria-en",
            # base_url="http://0.0.0.0:8080/v1/speak"
        )

        llm = OpenAILLMService(
            # To use OpenAI
            api_key=os.getenv("OPENAI_API_KEY"),
            model="gpt-4o"
            # Or, to use a local vLLM (or similar) api server
            # model="meta-llama/Meta-Llama-3-8B-Instruct",
            # model="neuralmagic/Meta-Llama-3-70B-Instruct-FP8",
            # base_url="http://0.0.0.0:8000/v1"
        )

        messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM communicating via audio. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
            },
        ]

        # NOTE(review): the shared context starts empty, so the system prompt
        # in `messages` is only sent with the introduction frame below and is
        # not part of the context used for later turns -- confirm intended.
        ctx = OpenAILLMContext()
        greedy = GreedyLLMAggregator(name="greedy", context=ctx)
        gate = VADGate(name="gate", vad_analyzer=transport.input().vad_analyzer(), context=ctx)

        pipeline = Pipeline([
            transport.input(),   # Transport user input
            greedy,
            llm,                 # LLM
            tts,                 # TTS
            gate,
            transport.output(),  # Transport bot output
            # FrameLogger()
        ])

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))

        # When a participant joins, start transcription for that participant so
        # the bot can "hear" and respond to them.
        @transport.event_handler("on_participant_joined")
        async def on_participant_joined(transport, participant):
            transport.capture_participant_transcription(participant["id"])

        # When the first participant joins, the bot should introduce itself.
        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            messages.append(
                {"role": "system", "content": "Please introduce yourself to the user."})
            await task.queue_frames([LLMMessagesFrame(messages)])

        # Handle "latency-ping" messages. The client will send app messages
        # that look like this:
        #   { "latency-ping": { ts: <client-side timestamp> }}
        #
        # We send an immediate pong back to the client from this handler.
        # We also push a second pong frame into the pipeline from its first
        # processor, so it is delivered by the transport output after
        # travelling through the pipeline.
        @transport.event_handler("on_app_message")
        async def on_app_message(transport, message, sender):
            try:
                if "latency-ping" in message:
                    logger.debug(f"Received latency ping app message: {message}")
                    ts = message["latency-ping"]["ts"]

                    # Send immediately. send_message may be a coroutine
                    # function depending on the transport version; await the
                    # result only when it is awaitable so the message is
                    # actually sent in both cases.
                    sent = transport.output().send_message(DailyTransportMessageFrame(
                        message={"latency-pong-msg-handler": {"ts": ts}},
                        participant_id=sender))
                    if inspect.isawaitable(sent):
                        await sent

                    # And push to the pipeline for the Daily transport output
                    # to send. `greedy` is the first processor after the
                    # transport input (the original referenced an undefined
                    # `tma_in`, which raised NameError here).
                    await greedy.push_frame(
                        DailyTransportMessageFrame(
                            message={"latency-pong-pipeline-delivery": {"ts": ts}},
                            participant_id=sender))
            except Exception as e:
                logger.debug(f"message handling error: {e} - {message}")

        runner = PipelineRunner()
        await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # configure() yields the room URL and token that main() expects.
    url, token = configure()
    asyncio.run(main(url, token))
|
||||
@@ -71,6 +71,8 @@ class PipelineTask:
|
||||
await self._source.process_frame(CancelFrame(), FrameDirection.DOWNSTREAM)
|
||||
self._process_down_task.cancel()
|
||||
self._process_up_task.cancel()
|
||||
await self._process_down_task
|
||||
await self._process_up_task
|
||||
|
||||
async def run(self):
|
||||
self._process_up_task = asyncio.create_task(self._process_up_queue())
|
||||
@@ -122,6 +124,7 @@ class PipelineTask:
|
||||
await self._pipeline.cleanup()
|
||||
# We just enqueue None to terminate the task gracefully.
|
||||
self._process_up_task.cancel()
|
||||
await self._process_up_task
|
||||
|
||||
async def _process_up_queue(self):
|
||||
while True:
|
||||
|
||||
@@ -40,14 +40,17 @@ class AnthropicLLMService(LLMService):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "claude-3-opus-20240229",
|
||||
max_tokens: int = 1024):
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "claude-3-opus-20240229",
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self._client = AsyncAnthropic(api_key=api_key)
|
||||
self._model = model
|
||||
self._max_tokens = max_tokens
|
||||
self._temperature = temperature
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -110,6 +113,7 @@ class AnthropicLLMService(LLMService):
|
||||
messages=messages,
|
||||
model=self._model,
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=self._temperature,
|
||||
stream=True)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -146,10 +146,11 @@ class AzureSTTService(AIService):
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
def _create_push_task(self):
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
|
||||
@@ -134,10 +134,11 @@ class DeepgramSTTService(AIService):
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await self._connection.finish()
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
def _create_push_task(self):
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _push_frame_task_handler(self):
|
||||
running = True
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
@@ -33,8 +32,6 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
self._params = params
|
||||
|
||||
self._running = False
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
# Create push frame task. This is the task that will push frames in
|
||||
@@ -42,34 +39,21 @@ class BaseInputTransport(FrameProcessor):
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Create audio input queue and thread if needed.
|
||||
# Create audio input queue and task if needed.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_queue = queue.Queue()
|
||||
self._audio_thread = self._loop.run_in_executor(
|
||||
self._executor, self._audio_thread_handler)
|
||||
self._audio_in_queue = asyncio.Queue()
|
||||
self._audio_task = self.get_event_loop().create_task(self._audio_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# This will exit all threads.
|
||||
self._running = False
|
||||
|
||||
# Wait for the threads to finish.
|
||||
# Wait for the task to finish.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
await self._audio_thread
|
||||
|
||||
self._push_frame_task.cancel()
|
||||
self._audio_task.cancel()
|
||||
await self._audio_task
|
||||
|
||||
def vad_analyzer(self) -> VADAnalyzer | None:
|
||||
return self._params.vad_analyzer
|
||||
|
||||
def push_audio_frame(self, frame: AudioRawFrame):
|
||||
async def push_audio_frame(self, frame: AudioRawFrame):
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_queue.put_nowait(frame)
|
||||
|
||||
@@ -78,7 +62,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
pass
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -102,8 +87,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
def _create_push_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
@@ -129,6 +114,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
logger.debug("User started speaking")
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
self._create_push_task()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
@@ -140,15 +126,16 @@ class BaseInputTransport(FrameProcessor):
|
||||
# Audio input
|
||||
#
|
||||
|
||||
def _vad_analyze(self, audio_frames: bytes) -> VADState:
|
||||
async def _vad_analyze(self, audio_frames: bytes) -> VADState:
|
||||
state = VADState.QUIET
|
||||
vad_analyzer = self.vad_analyzer()
|
||||
if vad_analyzer:
|
||||
state = vad_analyzer.analyze_audio(audio_frames)
|
||||
state = await self.get_event_loop().run_in_executor(
|
||||
self._executor, vad_analyzer.analyze_audio, audio_frames)
|
||||
return state
|
||||
|
||||
def _handle_vad(self, audio_frames: bytes, vad_state: VADState):
|
||||
new_vad_state = self._vad_analyze(audio_frames)
|
||||
async def _handle_vad(self, audio_frames: bytes, vad_state: VADState):
|
||||
new_vad_state = await self._vad_analyze(audio_frames)
|
||||
if new_vad_state != vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING:
|
||||
frame = None
|
||||
if new_vad_state == VADState.SPEAKING:
|
||||
@@ -157,33 +144,29 @@ class BaseInputTransport(FrameProcessor):
|
||||
frame = UserStoppedSpeakingFrame()
|
||||
|
||||
if frame:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._handle_interruptions(frame), self.get_event_loop())
|
||||
future.result()
|
||||
await self._handle_interruptions(frame)
|
||||
|
||||
vad_state = new_vad_state
|
||||
return vad_state
|
||||
|
||||
def _audio_thread_handler(self):
|
||||
async def _audio_task_handler(self):
|
||||
vad_state: VADState = VADState.QUIET
|
||||
while self._running:
|
||||
while True:
|
||||
try:
|
||||
frame: AudioRawFrame = self._audio_in_queue.get(timeout=1)
|
||||
frame: AudioRawFrame = await self._audio_in_queue.get()
|
||||
|
||||
audio_passthrough = True
|
||||
|
||||
# Check VAD and push event if necessary. We just care about
|
||||
# changes from QUIET to SPEAKING and vice versa.
|
||||
if self._params.vad_enabled:
|
||||
vad_state = self._handle_vad(frame.audio, vad_state)
|
||||
vad_state = await self._handle_vad(frame.audio, vad_state)
|
||||
audio_passthrough = self._params.vad_audio_passthrough
|
||||
|
||||
# Push audio downstream if passthrough.
|
||||
if audio_passthrough:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self._loop)
|
||||
future.result()
|
||||
except queue.Empty:
|
||||
pass
|
||||
await self._internal_push_frame(frame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except BaseException as e:
|
||||
logger.error(f"{self} error reading audio frames: {e}")
|
||||
|
||||
@@ -7,11 +7,6 @@
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import queue
|
||||
import time
|
||||
import threading
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
@@ -42,67 +37,51 @@ class BaseOutputTransport(FrameProcessor):
|
||||
|
||||
self._params = params
|
||||
|
||||
self._running = False
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
# These are the images that we should send to the camera at our desired
|
||||
# framerate.
|
||||
self._camera_images = None
|
||||
|
||||
# Create media threads queues.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_queue = queue.Queue()
|
||||
self._sink_queue = queue.Queue()
|
||||
self._sink_thread = None
|
||||
|
||||
self._stopped_event = asyncio.Event()
|
||||
self._is_interrupted = threading.Event()
|
||||
|
||||
# We will write 20ms audio at a time. If we receive long audio frames we
|
||||
# will chunk them. This will help with interruption handling.
|
||||
audio_bytes_10ms = int(self._params.audio_out_sample_rate / 100) * \
|
||||
self._params.audio_out_channels * 2
|
||||
self._audio_chunk_size = audio_bytes_10ms * 2
|
||||
|
||||
self._stopped_event = asyncio.Event()
|
||||
|
||||
# Create sink frame task. This is the task that will actually write
|
||||
# audio or video frames. We write audio/video in a task so we can keep
|
||||
# generating frames upstream while, for example, the audio is playing.
|
||||
self._create_sink_task()
|
||||
|
||||
# Create push frame task. This is the task that will push frames in
|
||||
# order. We also guarantee that all frames are pushed in the same task.
|
||||
self._create_push_task()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
loop = self.get_event_loop()
|
||||
|
||||
# Create queues and threads.
|
||||
# Create media threads queues.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_thread = loop.run_in_executor(
|
||||
self._executor, self._camera_out_thread_handler)
|
||||
|
||||
self._sink_thread = loop.run_in_executor(self._executor, self._sink_thread_handler)
|
||||
self._camera_out_queue = asyncio.Queue()
|
||||
self._camera_out_task = self.get_event_loop().create_task(self._camera_out_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# This will exit all threads.
|
||||
self._running = False
|
||||
# Wait on the threads to finish.
|
||||
if self._params.camera_out_enabled:
|
||||
self._camera_out_task.cancel()
|
||||
await self._camera_out_task
|
||||
|
||||
self._stopped_event.set()
|
||||
|
||||
def send_message(self, frame: TransportMessageFrame):
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
pass
|
||||
|
||||
def send_metrics(self, frame: MetricsFrame):
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
pass
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
pass
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
pass
|
||||
|
||||
#
|
||||
@@ -110,12 +89,12 @@ class BaseOutputTransport(FrameProcessor):
|
||||
#
|
||||
|
||||
async def cleanup(self):
|
||||
# Wait on the threads to finish.
|
||||
if self._params.camera_out_enabled:
|
||||
await self._camera_out_thread
|
||||
if self._sink_task:
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
|
||||
if self._sink_thread:
|
||||
await self._sink_thread
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -128,7 +107,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
# EndFrame is managed in the queue handler.
|
||||
# EndFrame is managed in the sink queue handler.
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.stop()
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -136,14 +115,14 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._handle_interruptions(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
self.send_metrics(frame)
|
||||
await self.send_metrics(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
else:
|
||||
self._sink_queue.put_nowait(frame)
|
||||
await self._sink_queue.put(frame)
|
||||
|
||||
# If we are finishing, wait here until we have stopped, otherwise we might
|
||||
# close things too early upstream. We need this event because we don't
|
||||
@@ -156,50 +135,51 @@ class BaseOutputTransport(FrameProcessor):
|
||||
return
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._is_interrupted.set()
|
||||
# Stop sink task.
|
||||
self._sink_task.cancel()
|
||||
await self._sink_task
|
||||
self._create_sink_task()
|
||||
# Stop push task.
|
||||
self._push_frame_task.cancel()
|
||||
await self._push_frame_task
|
||||
self._create_push_task()
|
||||
elif isinstance(frame, StopInterruptionFrame):
|
||||
self._is_interrupted.clear()
|
||||
|
||||
async def _handle_audio(self, frame: AudioRawFrame):
|
||||
audio = frame.audio
|
||||
for i in range(0, len(audio), self._audio_chunk_size):
|
||||
chunk = AudioRawFrame(audio[i: i + self._audio_chunk_size],
|
||||
sample_rate=frame.sample_rate, num_channels=frame.num_channels)
|
||||
self._sink_queue.put_nowait(chunk)
|
||||
await self._sink_queue.put(chunk)
|
||||
|
||||
def _sink_thread_handler(self):
|
||||
def _create_sink_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._sink_queue = asyncio.Queue()
|
||||
self._sink_task = loop.create_task(self._sink_task_handler())
|
||||
|
||||
async def _sink_task_handler(self):
|
||||
# Audio accumulation buffer
|
||||
buffer = bytearray()
|
||||
while self._running:
|
||||
while True:
|
||||
try:
|
||||
frame = self._sink_queue.get(timeout=1)
|
||||
if not self._is_interrupted.is_set():
|
||||
if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled:
|
||||
buffer.extend(frame.audio)
|
||||
buffer = self._maybe_send_audio(buffer)
|
||||
elif isinstance(frame, ImageRawFrame) and self._params.camera_out_enabled:
|
||||
self._set_camera_image(frame)
|
||||
elif isinstance(frame, SpriteFrame) and self._params.camera_out_enabled:
|
||||
self._set_camera_images(frame.images)
|
||||
elif isinstance(frame, TransportMessageFrame):
|
||||
self.send_message(frame)
|
||||
else:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
frame = await self._sink_queue.get()
|
||||
if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled:
|
||||
buffer.extend(frame.audio)
|
||||
buffer = await self._maybe_send_audio(buffer)
|
||||
elif isinstance(frame, ImageRawFrame) and self._params.camera_out_enabled:
|
||||
await self._set_camera_image(frame)
|
||||
elif isinstance(frame, SpriteFrame) and self._params.camera_out_enabled:
|
||||
await self._set_camera_images(frame.images)
|
||||
elif isinstance(frame, TransportMessageFrame):
|
||||
await self.send_message(frame)
|
||||
else:
|
||||
# If we get interrupted just clear the output buffer.
|
||||
buffer = bytearray()
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
if isinstance(frame, EndFrame):
|
||||
future = asyncio.run_coroutine_threadsafe(self.stop(), self.get_event_loop())
|
||||
future.result()
|
||||
await self.stop()
|
||||
|
||||
self._sink_queue.task_done()
|
||||
except queue.Empty:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except BaseException as e:
|
||||
logger.error(f"{self} error processing sink queue: {e}")
|
||||
|
||||
@@ -209,8 +189,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
|
||||
def _create_push_task(self):
|
||||
loop = self.get_event_loop()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
self._push_queue = asyncio.Queue()
|
||||
self._push_frame_task = loop.create_task(self._push_frame_task_handler())
|
||||
|
||||
async def _internal_push_frame(
|
||||
self,
|
||||
@@ -233,7 +213,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
async def send_image(self, frame: ImageRawFrame | SpriteFrame):
|
||||
await self.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
def _draw_image(self, frame: ImageRawFrame):
|
||||
async def _draw_image(self, frame: ImageRawFrame):
|
||||
desired_size = (self._params.camera_out_width, self._params.camera_out_height)
|
||||
|
||||
if frame.size != desired_size:
|
||||
@@ -243,32 +223,32 @@ class BaseOutputTransport(FrameProcessor):
|
||||
f"{frame} does not have the expected size {desired_size}, resizing")
|
||||
frame = ImageRawFrame(resized_image.tobytes(), resized_image.size, resized_image.format)
|
||||
|
||||
self.write_frame_to_camera(frame)
|
||||
await self.write_frame_to_camera(frame)
|
||||
|
||||
def _set_camera_image(self, image: ImageRawFrame):
|
||||
async def _set_camera_image(self, image: ImageRawFrame):
|
||||
if self._params.camera_out_is_live:
|
||||
self._camera_out_queue.put_nowait(image)
|
||||
await self._camera_out_queue.put(image)
|
||||
else:
|
||||
self._camera_images = itertools.cycle([image])
|
||||
|
||||
def _set_camera_images(self, images: List[ImageRawFrame]):
|
||||
async def _set_camera_images(self, images: List[ImageRawFrame]):
|
||||
self._camera_images = itertools.cycle(images)
|
||||
|
||||
def _camera_out_thread_handler(self):
|
||||
while self._running:
|
||||
async def _camera_out_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
if self._params.camera_out_is_live:
|
||||
image = self._camera_out_queue.get(timeout=1)
|
||||
self._draw_image(image)
|
||||
image = await self._camera_out_queue.get()
|
||||
await self._draw_image(image)
|
||||
self._camera_out_queue.task_done()
|
||||
elif self._camera_images:
|
||||
image = next(self._camera_images)
|
||||
self._draw_image(image)
|
||||
time.sleep(1.0 / self._params.camera_out_framerate)
|
||||
await self._draw_image(image)
|
||||
await asyncio.sleep(1.0 / self._params.camera_out_framerate)
|
||||
else:
|
||||
time.sleep(1.0 / self._params.camera_out_framerate)
|
||||
except queue.Empty:
|
||||
pass
|
||||
await asyncio.sleep(1.0 / self._params.camera_out_framerate)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error writing to camera: {e}")
|
||||
|
||||
@@ -279,12 +259,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
async def send_audio(self, frame: AudioRawFrame):
|
||||
await self.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
def _maybe_send_audio(self, buffer: bytearray) -> bytearray:
|
||||
try:
|
||||
if len(buffer) >= self._audio_chunk_size:
|
||||
self.write_raw_audio_frames(bytes(buffer[:self._audio_chunk_size]))
|
||||
buffer = buffer[self._audio_chunk_size:]
|
||||
return buffer
|
||||
except BaseException as e:
|
||||
logger.error(f"{self} error writing audio frames: {e}")
|
||||
return buffer
|
||||
async def _maybe_send_audio(self, buffer: bytearray) -> bytearray:
|
||||
if len(buffer) >= self._audio_chunk_size:
|
||||
await self.write_raw_audio_frames(bytes(buffer[:self._audio_chunk_size]))
|
||||
buffer = buffer[self._audio_chunk_size:]
|
||||
return buffer
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, StartFrame
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
@@ -43,26 +45,20 @@ class LocalAudioInputTransport(BaseInputTransport):
|
||||
await super().start(frame)
|
||||
self._in_stream.start_stream()
|
||||
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
self._in_stream.stop_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._in_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._in_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._in_stream.close()
|
||||
|
||||
await super().cleanup()
|
||||
|
||||
def _audio_in_callback(self, in_data, frame_count, time_info, status):
|
||||
if not self._running:
|
||||
return (None, pyaudio.paAbort)
|
||||
|
||||
frame = AudioRawFrame(audio=in_data,
|
||||
sample_rate=self._params.audio_in_sample_rate,
|
||||
num_channels=self._params.audio_in_channels)
|
||||
self.push_audio_frame(frame)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(self.push_audio_frame(frame), self.get_event_loop())
|
||||
|
||||
return (None, pyaudio.paContinue)
|
||||
|
||||
@@ -72,19 +68,29 @@ class LocalAudioOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams):
|
||||
super().__init__(params)
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
self._out_stream = py_audio.open(
|
||||
format=py_audio.get_format_from_width(2),
|
||||
channels=params.audio_out_channels,
|
||||
rate=params.audio_out_sample_rate,
|
||||
output=True)
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._out_stream.write(frames)
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._out_stream.start_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._out_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._out_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._out_stream.close()
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
await self.get_event_loop().run_in_executor(self._executor, self._out_stream.write, frames)
|
||||
|
||||
|
||||
class LocalAudioTransport(BaseTransport):
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
import tkinter as tk
|
||||
|
||||
@@ -53,25 +55,20 @@ class TkInputTransport(BaseInputTransport):
|
||||
await super().start(frame)
|
||||
self._in_stream.start_stream()
|
||||
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
self._in_stream.stop_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._in_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._in_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._in_stream.close()
|
||||
|
||||
def _audio_in_callback(self, in_data, frame_count, time_info, status):
|
||||
if not self._running:
|
||||
return (None, pyaudio.paAbort)
|
||||
|
||||
frame = AudioRawFrame(audio=in_data,
|
||||
sample_rate=self._params.audio_in_sample_rate,
|
||||
num_channels=self._params.audio_in_channels)
|
||||
self.push_audio_frame(frame)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(self.push_audio_frame(frame), self.get_event_loop())
|
||||
|
||||
return (None, pyaudio.paContinue)
|
||||
|
||||
@@ -81,6 +78,8 @@ class TkOutputTransport(BaseOutputTransport):
|
||||
def __init__(self, tk_root: tk.Tk, py_audio: pyaudio.PyAudio, params: TransportParams):
|
||||
super().__init__(params)
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
self._out_stream = py_audio.open(
|
||||
format=py_audio.get_format_from_width(2),
|
||||
channels=params.audio_out_channels,
|
||||
@@ -94,16 +93,24 @@ class TkOutputTransport(BaseOutputTransport):
|
||||
self._image_label = tk.Label(tk_root, image=photo)
|
||||
self._image_label.pack()
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._out_stream.write(frames)
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self.get_event_loop().call_soon(self._write_frame_to_tk, frame)
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._out_stream.start_stream()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
self._out_stream.stop_stream()
|
||||
# This is not very pretty (taken from PyAudio docs).
|
||||
while self._out_stream.is_active():
|
||||
await asyncio.sleep(0.1)
|
||||
self._out_stream.close()
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
await self.get_event_loop().run_in_executor(self._executor, self._out_stream.write, frames)
|
||||
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self.get_event_loop().call_soon(self._write_frame_to_tk, frame)
|
||||
|
||||
def _write_frame_to_tk(self, frame: ImageRawFrame):
|
||||
width = frame.size[0]
|
||||
height = frame.size[1]
|
||||
|
||||
@@ -88,7 +88,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
continue
|
||||
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
self.push_audio_frame(frame)
|
||||
await self.push_audio_frame(frame)
|
||||
else:
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
@@ -118,7 +118,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
logger.warning("Only one client allowed, using new connection")
|
||||
self._websocket = websocket
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
self._audio_buffer += frames
|
||||
while len(self._audio_buffer) >= self._params.audio_frame_size:
|
||||
frame = AudioRawFrame(
|
||||
@@ -144,9 +144,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
|
||||
proto = self._params.serializer.serialize(frame)
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._websocket.send(proto), self.get_event_loop())
|
||||
future.result()
|
||||
await self._websocket.send(proto)
|
||||
|
||||
self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:]
|
||||
|
||||
|
||||
@@ -6,11 +6,10 @@
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import queue
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Mapping
|
||||
from typing import Any, Awaitable, Callable, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from daily import (
|
||||
@@ -108,19 +107,26 @@ class DailyParams(TransportParams):
|
||||
|
||||
|
||||
class DailyCallbacks(BaseModel):
|
||||
on_joined: Callable[[Mapping[str, Any]], None]
|
||||
on_left: Callable[[], None]
|
||||
on_error: Callable[[str], None]
|
||||
on_app_message: Callable[[Any, str], None]
|
||||
on_call_state_updated: Callable[[str], None]
|
||||
on_dialin_ready: Callable[[str], None]
|
||||
on_dialout_connected: Callable[[Any], None]
|
||||
on_dialout_stopped: Callable[[Any], None]
|
||||
on_dialout_error: Callable[[Any], None]
|
||||
on_dialout_warning: Callable[[Any], None]
|
||||
on_first_participant_joined: Callable[[Mapping[str, Any]], None]
|
||||
on_participant_joined: Callable[[Mapping[str, Any]], None]
|
||||
on_participant_left: Callable[[Mapping[str, Any], str], None]
|
||||
on_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_left: Callable[[], Awaitable[None]]
|
||||
on_error: Callable[[str], Awaitable[None]]
|
||||
on_app_message: Callable[[Any, str], Awaitable[None]]
|
||||
on_call_state_updated: Callable[[str], Awaitable[None]]
|
||||
on_dialin_ready: Callable[[str], Awaitable[None]]
|
||||
on_dialout_connected: Callable[[Any], Awaitable[None]]
|
||||
on_dialout_stopped: Callable[[Any], Awaitable[None]]
|
||||
on_dialout_error: Callable[[Any], Awaitable[None]]
|
||||
on_dialout_warning: Callable[[Any], Awaitable[None]]
|
||||
on_first_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]]
|
||||
on_participant_left: Callable[[Mapping[str, Any], str], Awaitable[None]]
|
||||
|
||||
|
||||
def completion_callback(future):
|
||||
def _callback(*args):
|
||||
if not future.cancelled():
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, *args)
|
||||
return _callback
|
||||
|
||||
|
||||
class DailyTransportClient(EventHandler):
|
||||
@@ -160,7 +166,6 @@ class DailyTransportClient(EventHandler):
|
||||
self._joined = False
|
||||
self._joining = False
|
||||
self._leaving = False
|
||||
self._sync_response = {k: queue.Queue() for k in ["join", "leave"]}
|
||||
|
||||
self._executor = ThreadPoolExecutor(max_workers=5)
|
||||
|
||||
@@ -173,10 +178,16 @@ class DailyTransportClient(EventHandler):
|
||||
color_format=self._params.camera_out_color_format)
|
||||
|
||||
self._mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
|
||||
"mic", sample_rate=self._params.audio_out_sample_rate, channels=self._params.audio_out_channels)
|
||||
"mic",
|
||||
sample_rate=self._params.audio_out_sample_rate,
|
||||
channels=self._params.audio_out_channels,
|
||||
non_blocking=True)
|
||||
|
||||
self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=self._params.audio_in_sample_rate, channels=self._params.audio_in_channels)
|
||||
"speaker",
|
||||
sample_rate=self._params.audio_in_sample_rate,
|
||||
channels=self._params.audio_in_channels,
|
||||
non_blocking=True)
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
@property
|
||||
@@ -186,30 +197,39 @@ class DailyTransportClient(EventHandler):
|
||||
def set_callbacks(self, callbacks: DailyCallbacks):
|
||||
self._callbacks = callbacks
|
||||
|
||||
def send_message(self, frame: DailyTransportMessageFrame):
|
||||
self._client.send_app_message(frame.message, frame.participant_id)
|
||||
async def send_message(self, frame: DailyTransportMessageFrame):
|
||||
future = self._loop.create_future()
|
||||
self._client.send_app_message(
|
||||
frame.message,
|
||||
frame.participant_id,
|
||||
completion=completion_callback(future))
|
||||
await future
|
||||
|
||||
def read_next_audio_frame(self) -> AudioRawFrame | None:
|
||||
async def read_next_audio_frame(self) -> AudioRawFrame | None:
|
||||
sample_rate = self._params.audio_in_sample_rate
|
||||
num_channels = self._params.audio_in_channels
|
||||
|
||||
if self._other_participant_has_joined:
|
||||
num_frames = int(sample_rate / 100) * 2 # 20ms of audio
|
||||
|
||||
audio = self._speaker.read_frames(num_frames)
|
||||
future = self._loop.create_future()
|
||||
self._speaker.read_frames(num_frames, completion=completion_callback(future))
|
||||
audio = await future
|
||||
|
||||
return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels)
|
||||
else:
|
||||
# If no one has ever joined the meeting `read_frames()` would block,
|
||||
# instead we just wait a bit. daily-python should probably return
|
||||
# silence instead.
|
||||
time.sleep(0.01)
|
||||
await asyncio.sleep(0.01)
|
||||
return None
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._mic.write_frames(frames)
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
future = self._loop.create_future()
|
||||
self._mic.write_frames(frames, completion=completion_callback(future))
|
||||
await future
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self._camera.write_frame(frame.image)
|
||||
|
||||
async def join(self):
|
||||
@@ -217,13 +237,10 @@ class DailyTransportClient(EventHandler):
|
||||
if self._joined or self._joining:
|
||||
return
|
||||
|
||||
self._joining = True
|
||||
|
||||
await self._loop.run_in_executor(self._executor, self._join)
|
||||
|
||||
def _join(self):
|
||||
logger.info(f"Joining {self._room_url}")
|
||||
|
||||
self._joining = True
|
||||
|
||||
# For performance reasons, never subscribe to video streams (unless a
|
||||
# video renderer is registered).
|
||||
self._client.update_subscription_profiles({
|
||||
@@ -235,10 +252,42 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
self._client.set_user_name(self._bot_name)
|
||||
|
||||
try:
|
||||
(data, error) = await self._join()
|
||||
|
||||
if not error:
|
||||
self._joined = True
|
||||
self._joining = False
|
||||
|
||||
logger.info(f"Joined {self._room_url}")
|
||||
|
||||
if self._token and self._params.transcription_enabled:
|
||||
logger.info(
|
||||
f"Enabling transcription with settings {self._params.transcription_settings}")
|
||||
self._client.start_transcription(
|
||||
self._params.transcription_settings.model_dump())
|
||||
|
||||
await self._callbacks.on_joined(data["participants"]["local"])
|
||||
else:
|
||||
error_msg = f"Error joining {self._room_url}: {error}"
|
||||
logger.error(error_msg)
|
||||
await self._callbacks.on_error(error_msg)
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"Time out joining {self._room_url}"
|
||||
logger.error(error_msg)
|
||||
await self._callbacks.on_error(error_msg)
|
||||
|
||||
async def _join(self):
|
||||
future = self._loop.create_future()
|
||||
|
||||
def handle_join_response(data, error):
|
||||
if not future.cancelled():
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, (data, error))
|
||||
|
||||
self._client.join(
|
||||
self._room_url,
|
||||
self._token,
|
||||
completion=self._call_joined,
|
||||
completion=handle_join_response,
|
||||
client_settings={
|
||||
"inputs": {
|
||||
"camera": {
|
||||
@@ -274,33 +323,7 @@ class DailyTransportClient(EventHandler):
|
||||
},
|
||||
})
|
||||
|
||||
self._handle_join_response()
|
||||
|
||||
def _handle_join_response(self):
|
||||
try:
|
||||
(data, error) = self._sync_response["join"].get(timeout=10)
|
||||
if not error:
|
||||
self._joined = True
|
||||
self._joining = False
|
||||
|
||||
logger.info(f"Joined {self._room_url}")
|
||||
|
||||
if self._token and self._params.transcription_enabled:
|
||||
logger.info(
|
||||
f"Enabling transcription with settings {self._params.transcription_settings}")
|
||||
self._client.start_transcription(
|
||||
self._params.transcription_settings.model_dump())
|
||||
|
||||
self._callbacks.on_joined(data["participants"]["local"])
|
||||
else:
|
||||
error_msg = f"Error joining {self._room_url}: {error}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
self._sync_response["join"].task_done()
|
||||
except queue.Empty:
|
||||
error_msg = f"Time out joining {self._room_url}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
return await asyncio.wait_for(future, timeout=10)
|
||||
|
||||
async def leave(self):
|
||||
# Transport not joined, ignore.
|
||||
@@ -310,34 +333,36 @@ class DailyTransportClient(EventHandler):
|
||||
self._joined = False
|
||||
self._leaving = True
|
||||
|
||||
await self._loop.run_in_executor(self._executor, self._leave)
|
||||
|
||||
def _leave(self):
|
||||
logger.info(f"Leaving {self._room_url}")
|
||||
|
||||
if self._params.transcription_enabled:
|
||||
self._client.stop_transcription()
|
||||
|
||||
self._client.leave(completion=self._call_left)
|
||||
|
||||
self._handle_leave_response()
|
||||
|
||||
def _handle_leave_response(self):
|
||||
try:
|
||||
error = self._sync_response["leave"].get(timeout=10)
|
||||
error = await self._leave()
|
||||
if not error:
|
||||
self._leaving = False
|
||||
logger.info(f"Left {self._room_url}")
|
||||
self._callbacks.on_left()
|
||||
await self._callbacks.on_left()
|
||||
else:
|
||||
error_msg = f"Error leaving {self._room_url}: {error}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
self._sync_response["leave"].task_done()
|
||||
except queue.Empty:
|
||||
await self._callbacks.on_error(error_msg)
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"Time out leaving {self._room_url}"
|
||||
logger.error(error_msg)
|
||||
self._callbacks.on_error(error_msg)
|
||||
await self._callbacks.on_error(error_msg)
|
||||
|
||||
async def _leave(self):
|
||||
future = self._loop.create_future()
|
||||
|
||||
def handle_leave_response(error):
|
||||
if not future.cancelled():
|
||||
future.get_loop().call_soon_threadsafe(future.set_result, error)
|
||||
|
||||
self._client.leave(completion=handle_leave_response)
|
||||
|
||||
return await asyncio.wait_for(future, timeout=10)
|
||||
|
||||
async def cleanup(self):
|
||||
await self._loop.run_in_executor(self._executor, self._cleanup)
|
||||
@@ -399,25 +424,25 @@ class DailyTransportClient(EventHandler):
|
||||
#
|
||||
|
||||
def on_app_message(self, message: Any, sender: str):
|
||||
self._callbacks.on_app_message(message, sender)
|
||||
self._call_async_callback(self._callbacks.on_app_message, message, sender)
|
||||
|
||||
def on_call_state_updated(self, state: str):
|
||||
self._callbacks.on_call_state_updated(state)
|
||||
self._call_async_callback(self._callbacks.on_call_state_updated, state)
|
||||
|
||||
def on_dialin_ready(self, sip_endpoint: str):
|
||||
self._callbacks.on_dialin_ready(sip_endpoint)
|
||||
self._call_async_callback(self._callbacks.on_dialin_ready, sip_endpoint)
|
||||
|
||||
def on_dialout_connected(self, data: Any):
|
||||
self._callbacks.on_dialout_connected(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_connected, data)
|
||||
|
||||
def on_dialout_stopped(self, data: Any):
|
||||
self._callbacks.on_dialout_stopped(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_stopped, data)
|
||||
|
||||
def on_dialout_error(self, data: Any):
|
||||
self._callbacks.on_dialout_error(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_error, data)
|
||||
|
||||
def on_dialout_warning(self, data: Any):
|
||||
self._callbacks.on_dialout_warning(data)
|
||||
self._call_async_callback(self._callbacks.on_dialout_warning, data)
|
||||
|
||||
def on_participant_joined(self, participant):
|
||||
id = participant["id"]
|
||||
@@ -425,15 +450,15 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
if not self._other_participant_has_joined:
|
||||
self._other_participant_has_joined = True
|
||||
self._callbacks.on_first_participant_joined(participant)
|
||||
self._call_async_callback(self._callbacks.on_first_participant_joined, participant)
|
||||
|
||||
self._callbacks.on_participant_joined(participant)
|
||||
self._call_async_callback(self._callbacks.on_participant_joined, participant)
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
id = participant["id"]
|
||||
logger.info(f"Participant left {id}")
|
||||
|
||||
self._callbacks.on_participant_left(participant, reason)
|
||||
self._call_async_callback(self._callbacks.on_participant_left, participant, reason)
|
||||
|
||||
def on_transcription_message(self, message: Mapping[str, Any]):
|
||||
participant_id = ""
|
||||
@@ -442,7 +467,7 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
if participant_id in self._transcription_renderers:
|
||||
callback = self._transcription_renderers[participant_id]
|
||||
callback(participant_id, message)
|
||||
self._call_async_callback(callback, participant_id, message)
|
||||
|
||||
def on_transcription_error(self, message):
|
||||
logger.error(f"Transcription error: {message}")
|
||||
@@ -457,18 +482,19 @@ class DailyTransportClient(EventHandler):
|
||||
# Daily (CallClient callbacks)
|
||||
#
|
||||
|
||||
def _call_joined(self, data, error):
|
||||
self._sync_response["join"].put((data, error))
|
||||
|
||||
def _call_left(self, error):
|
||||
self._sync_response["leave"].put(error)
|
||||
|
||||
def _video_frame_received(self, participant_id, video_frame):
|
||||
callback = self._video_renderers[participant_id]
|
||||
callback(participant_id,
|
||||
video_frame.buffer,
|
||||
(video_frame.width, video_frame.height),
|
||||
video_frame.color_format)
|
||||
self._call_async_callback(
|
||||
callback,
|
||||
participant_id,
|
||||
video_frame.buffer,
|
||||
(video_frame.width,
|
||||
video_frame.height),
|
||||
video_frame.color_format)
|
||||
|
||||
def _call_async_callback(self, callback, *args):
|
||||
future = asyncio.run_coroutine_threadsafe(callback(*args), self._loop)
|
||||
future.result()
|
||||
|
||||
|
||||
class DailyInputTransport(BaseInputTransport):
|
||||
@@ -487,8 +513,6 @@ class DailyInputTransport(BaseInputTransport):
|
||||
num_channels=self._params.audio_in_channels)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
# Parent start.
|
||||
await super().start(frame)
|
||||
# Join the room.
|
||||
@@ -496,19 +520,17 @@ class DailyInputTransport(BaseInputTransport):
|
||||
# Create audio task. It reads audio frames from Daily and pushes them
|
||||
# internally for VAD processing.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
self._audio_in_thread = self._loop.run_in_executor(
|
||||
self._executor, self._audio_in_thread_handler)
|
||||
self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler())
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
# Parent stop. This will set _running to False.
|
||||
# Parent stop.
|
||||
await super().stop()
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
# Stop audio thread.
|
||||
if self._params.audio_in_enabled or self._params.vad_enabled:
|
||||
await self._audio_in_thread
|
||||
self._audio_in_task.cancel()
|
||||
await self._audio_in_task
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
@@ -531,26 +553,25 @@ class DailyInputTransport(BaseInputTransport):
|
||||
# Frames
|
||||
#
|
||||
|
||||
def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
async def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame):
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
def push_app_message(self, message: Any, sender: str):
|
||||
async def push_app_message(self, message: Any, sender: str):
|
||||
frame = DailyTransportMessageFrame(message=message, participant_id=sender)
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
#
|
||||
# Audio in
|
||||
#
|
||||
|
||||
def _audio_in_thread_handler(self):
|
||||
while self._running:
|
||||
frame = self._client.read_next_audio_frame()
|
||||
if frame:
|
||||
self.push_audio_frame(frame)
|
||||
async def _audio_in_task_handler(self):
|
||||
while True:
|
||||
try:
|
||||
frame = await self._client.read_next_audio_frame()
|
||||
if frame:
|
||||
await self.push_audio_frame(frame)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
#
|
||||
# Camera in
|
||||
@@ -580,7 +601,7 @@ class DailyInputTransport(BaseInputTransport):
|
||||
if participant_id in self._video_renderers:
|
||||
self._video_renderers[participant_id]["render_next_frame"] = True
|
||||
|
||||
def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
|
||||
async def _on_participant_video_frame(self, participant_id: str, buffer, size, format):
|
||||
render_frame = False
|
||||
|
||||
curr_time = time.time()
|
||||
@@ -600,9 +621,7 @@ class DailyInputTransport(BaseInputTransport):
|
||||
image=buffer,
|
||||
size=size,
|
||||
format=format)
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._internal_push_frame(frame), self.get_event_loop())
|
||||
future.result()
|
||||
await self._internal_push_frame(frame)
|
||||
|
||||
self._video_renderers[participant_id]["timestamp"] = curr_time
|
||||
|
||||
@@ -615,17 +634,13 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
self._client = client
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
if self._running:
|
||||
return
|
||||
# Parent start.
|
||||
await super().start(frame)
|
||||
# Join the room.
|
||||
await self._client.join()
|
||||
|
||||
async def stop(self):
|
||||
if not self._running:
|
||||
return
|
||||
# Parent stop. This will set _running to False.
|
||||
# Parent stop.
|
||||
await super().stop()
|
||||
# Leave the room.
|
||||
await self._client.leave()
|
||||
@@ -634,10 +649,10 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
|
||||
def send_message(self, frame: DailyTransportMessageFrame):
|
||||
self._client.send_message(frame)
|
||||
async def send_message(self, frame: DailyTransportMessageFrame):
|
||||
await self._client.send_message(frame)
|
||||
|
||||
def send_metrics(self, frame: MetricsFrame):
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
ttfb = [{"name": n, "time": t} for n, t in frame.ttfb.items()]
|
||||
message = DailyTransportMessageFrame(message={
|
||||
"type": "pipecat-metrics",
|
||||
@@ -645,13 +660,13 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
"ttfb": ttfb
|
||||
},
|
||||
})
|
||||
self._client.send_message(message)
|
||||
await self._client.send_message(message)
|
||||
|
||||
def write_raw_audio_frames(self, frames: bytes):
|
||||
self._client.write_raw_audio_frames(frames)
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
await self._client.write_raw_audio_frames(frames)
|
||||
|
||||
def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
self._client.write_frame_to_camera(frame)
|
||||
async def write_frame_to_camera(self, frame: ImageRawFrame):
|
||||
await self._client.write_frame_to_camera(frame)
|
||||
|
||||
|
||||
class DailyTransport(BaseTransport):
|
||||
@@ -768,24 +783,24 @@ class DailyTransport(BaseTransport):
|
||||
self._input.capture_participant_video(
|
||||
participant_id, framerate, video_source, color_format)
|
||||
|
||||
def _on_joined(self, participant):
|
||||
self._call_async_event_handler("on_joined", participant)
|
||||
async def _on_joined(self, participant):
|
||||
await self._call_event_handler("on_joined", participant)
|
||||
|
||||
def _on_left(self):
|
||||
self._call_async_event_handler("on_left")
|
||||
async def _on_left(self):
|
||||
await self._call_event_handler("on_left")
|
||||
|
||||
def _on_error(self, error):
|
||||
async def _on_error(self, error):
|
||||
# TODO(aleix): Report error to input/output transports. The one managing
|
||||
# the client should report the error.
|
||||
pass
|
||||
|
||||
def _on_app_message(self, message: Any, sender: str):
|
||||
async def _on_app_message(self, message: Any, sender: str):
|
||||
if self._input:
|
||||
self._input.push_app_message(message, sender)
|
||||
self._call_async_event_handler("on_app_message", message, sender)
|
||||
await self._input.push_app_message(message, sender)
|
||||
await self._call_event_handler("on_app_message", message, sender)
|
||||
|
||||
def _on_call_state_updated(self, state: str):
|
||||
self._call_async_event_handler("on_call_state_updated", state)
|
||||
async def _on_call_state_updated(self, state: str):
|
||||
await self._call_event_handler("on_call_state_updated", state)
|
||||
|
||||
async def _handle_dialin_ready(self, sip_endpoint: str):
|
||||
if not self._params.dialin_settings:
|
||||
@@ -818,33 +833,33 @@ class DailyTransport(BaseTransport):
|
||||
except BaseException as e:
|
||||
logger.error(f"Error handling dialin-ready event ({url}): {e}")
|
||||
|
||||
def _on_dialin_ready(self, sip_endpoint):
|
||||
async def _on_dialin_ready(self, sip_endpoint):
|
||||
if self._params.dialin_settings:
|
||||
asyncio.run_coroutine_threadsafe(self._handle_dialin_ready(sip_endpoint), self._loop)
|
||||
self._call_async_event_handler("on_dialin_ready", sip_endpoint)
|
||||
await self._handle_dialin_ready(sip_endpoint)
|
||||
await self._call_event_handler("on_dialin_ready", sip_endpoint)
|
||||
|
||||
def _on_dialout_connected(self, data):
|
||||
self._call_async_event_handler("on_dialout_connected", data)
|
||||
async def _on_dialout_connected(self, data):
|
||||
await self._call_event_handler("on_dialout_connected", data)
|
||||
|
||||
def _on_dialout_stopped(self, data):
|
||||
self._call_async_event_handler("on_dialout_stopped", data)
|
||||
async def _on_dialout_stopped(self, data):
|
||||
await self._call_event_handler("on_dialout_stopped", data)
|
||||
|
||||
def _on_dialout_error(self, data):
|
||||
self._call_async_event_handler("on_dialout_error", data)
|
||||
async def _on_dialout_error(self, data):
|
||||
await self._call_event_handler("on_dialout_error", data)
|
||||
|
||||
def _on_dialout_warning(self, data):
|
||||
self._call_async_event_handler("on_dialout_warning", data)
|
||||
async def _on_dialout_warning(self, data):
|
||||
await self._call_event_handler("on_dialout_warning", data)
|
||||
|
||||
def _on_participant_joined(self, participant):
|
||||
self._call_async_event_handler("on_participant_joined", participant)
|
||||
async def _on_participant_joined(self, participant):
|
||||
await self._call_event_handler("on_participant_joined", participant)
|
||||
|
||||
def _on_participant_left(self, participant, reason):
|
||||
self._call_async_event_handler("on_participant_left", participant, reason)
|
||||
async def _on_participant_left(self, participant, reason):
|
||||
await self._call_event_handler("on_participant_left", participant, reason)
|
||||
|
||||
def _on_first_participant_joined(self, participant):
|
||||
self._call_async_event_handler("on_first_participant_joined", participant)
|
||||
async def _on_first_participant_joined(self, participant):
|
||||
await self._call_event_handler("on_first_participant_joined", participant)
|
||||
|
||||
def _on_transcription_message(self, participant_id, message):
|
||||
async def _on_transcription_message(self, participant_id, message):
|
||||
text = message["text"]
|
||||
timestamp = message["timestamp"]
|
||||
is_final = message["rawResponse"]["is_final"]
|
||||
@@ -855,9 +870,4 @@ class DailyTransport(BaseTransport):
|
||||
frame = InterimTranscriptionFrame(text, participant_id, timestamp)
|
||||
|
||||
if self._input:
|
||||
self._input.push_transcription_frame(frame)
|
||||
|
||||
def _call_async_event_handler(self, event_name: str, *args, **kwargs):
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._call_event_handler(event_name, *args, **kwargs), self._loop)
|
||||
future.result()
|
||||
await self._input.push_transcription_frame(frame)
|
||||
|
||||
86
tests/vllm-inference-test.py
Normal file
86
tests/vllm-inference-test.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
# Decoding settings shared by both entry points in this script.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=4096)

# Llama 3 chat-formatted prompt, written as adjacent literals that concatenate
# to a single string: a system turn, a second system turn asking for an
# introduction, and an open assistant header for the model to complete.
prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful LLM in a WebRTC call. "
    "Your goal is to demonstrate your capabilities in a succinct way. "
    "Your output will be converted to audio so don't include special characters in your answers. "
    "Respond to what the user said in a creative and helpful way."
    "<|eot_id|><|start_header_id|>system<|end_header_id|>\n\n"
    "Please introduce yourself to the user."
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
|
||||
|
||||
|
||||
async def main():
    """Cold-start an async vLLM engine, stream one completion, and report timing.

    Prints how long engine construction took, then each text delta as it
    arrives, and finally the token count, elapsed generation time and
    tokens-per-second throughput.
    """
    print("🥶 cold starting inference")
    boot_started = time.monotonic_ns()

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enable_prefix_caching=True,
            gpu_memory_utilization=0.90,
            # False means slower starts but faster inference.
            enforce_eager=False,
            # Disable logging so we can stream tokens.
            disable_log_stats=True,
            disable_log_requests=True,
        )
    )
    boot_seconds = (time.monotonic_ns() - boot_started) / 1e9
    print(f"🏎️ engine started in {boot_seconds:.0f}s")

    stream = engine.generate(prompt, sampling_params, random_uuid())

    # `emitted` is how many characters of the cumulative text were already
    # printed; `num_tokens` tracks the latest cumulative token count.
    emitted, num_tokens = 0, 0
    gen_started = time.monotonic_ns()
    async for output in stream:
        completion = output.outputs[0]
        # Skip updates whose text currently ends in U+FFFD (replacement
        # character); a later update will carry the text again.
        if completion.text.endswith("\ufffd"):
            continue
        text_delta = completion.text[emitted:]
        emitted = len(completion.text)
        num_tokens = len(completion.token_ids)

        print(text_delta)

    gen_seconds = (time.monotonic_ns() - gen_started) / 1e9
    print(
        f"\n\tGenerated {num_tokens} tokens in {gen_seconds:.1f}s,"
        f" throughput = {num_tokens / gen_seconds:.0f} tokens/second.\n"
    )
|
||||
|
||||
|
||||
async def xmain():
    """Run the shared prompt through the synchronous ``LLM`` API twice.

    Alternative to ``main`` that uses vLLM's blocking offline interface and
    prints the prompt and generated text of every returned output.

    The original body rebound ``prompt`` inside the loop
    (``prompt = output.prompt``). That made ``prompt`` a local name for the
    whole function, so the very first ``llm.generate(prompt, ...)`` raised
    ``UnboundLocalError``. The module-level ``prompt`` is now only read, and
    the echoed prompt is taken from ``output.prompt`` directly.
    """
    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        enable_prefix_caching=True,
    )

    # Two identical passes, as in the original code; presumably the second one
    # is meant to show the effect of prefix caching -- TODO confirm intent.
    for _ in range(2):
        for output in llm.generate(prompt, sampling_params):
            generated_text = output.outputs[0].text
            print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user