387 lines
13 KiB
Python
387 lines
13 KiB
Python
#
# Copyright (c) 2024–2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
|
||
|
||
import asyncio
|
||
import io
|
||
import os
|
||
import re
|
||
import time
|
||
import wave
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import aiofiles
|
||
from loguru import logger
|
||
from PIL.ImageFile import ImageFile
|
||
from utils import (
|
||
EvalResult,
|
||
load_module_from_path,
|
||
print_begin_test,
|
||
print_end_test,
|
||
print_test_results,
|
||
)
|
||
|
||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||
from pipecat.frames.frames import (
|
||
CancelFrame,
|
||
EndFrame,
|
||
EndTaskFrame,
|
||
LLMRunFrame,
|
||
OutputImageRawFrame,
|
||
)
|
||
from pipecat.pipeline.pipeline import Pipeline
|
||
from pipecat.pipeline.runner import PipelineRunner
|
||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||
from pipecat.processors.aggregators.llm_response_universal import (
|
||
LLMContextAggregatorPair,
|
||
LLMUserAggregatorParams,
|
||
)
|
||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||
from pipecat.processors.frame_processor import FrameDirection
|
||
from pipecat.runner.types import RunnerArguments
|
||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||
from pipecat.services.llm_service import FunctionCallParams
|
||
from pipecat.services.openai.llm import OpenAILLMService
|
||
from pipecat.transports.daily.transport import DailyParams, DailyTransport
|
||
|
||
# Directory containing this script; EvalRunner stores its "test-runs" output
# beneath it.
SCRIPT_DIR = Path(__file__).resolve().parent

# Idle timeout (seconds) given to both the example pipeline and the eval
# pipeline.
PIPELINE_IDLE_TIMEOUT_SECS = 60

# Maximum time (seconds) to wait for both pipelines to finish before the
# remaining tasks are cancelled.
EVAL_TIMEOUT_SECS = 120

# Maximum time (seconds) to wait for the eval result once the pipelines are
# done.
EVAL_RESULT_TIMEOUT_SECS = 10

# An eval prompt is either plain text or a (text, image) pair.
EvalPrompt = str | tuple[str, ImageFile]
|
||
|
||
|
||
@dataclass
class EvalConfig:
    """Configuration of a single eval run against one example."""

    # What the eval bot should ask/say: plain text or a (text, image) pair.
    prompt: EvalPrompt

    # Evaluation criteria; interpolated into the eval bot's system prompt.
    eval: str

    # When True the eval bot opens the conversation by saying the prompt.
    eval_speaks_first: bool = False

    # Assigned to `RunnerArguments.body` before the example bot is run.
    runner_args_body: Any | None = None
|
||
|
||
|
||
class EvalRunner:
    """Runs examples against an eval bot and keeps track of the results.

    For every example, `run_eval` starts the example pipeline and the eval
    pipeline concurrently, waits for the eval result and records an
    `EvalResult`. Logs and audio recordings are written below
    `SCRIPT_DIR/test-runs/<name>/`.
    """

    def __init__(
        self,
        *,
        examples_dir: Path,
        pattern: str = "",
        record_audio: bool = False,
        name: str | None = None,
        log_level: str = "DEBUG",
    ):
        """Initialize the runner and create its output directories.

        Args:
            examples_dir: Directory that contains the example scripts.
            pattern: Regular expression fragment used to select examples. It is
                wrapped in `.*` on both sides; an empty pattern selects
                every example.
            record_audio: Stored on the instance; not read by any method of
                this class.
            name: Name of the run directory. Defaults to the current local
                time formatted as `%Y%m%d_%H%M%S`.
            log_level: Level used for the per-example log file.
        """
        self._examples_dir = examples_dir
        self._pattern = f".*{pattern}.*" if pattern else ""
        self._record_audio = record_audio
        self._log_level = log_level
        self._total_success = 0
        self._tests: list[EvalResult] = []
        self._result_future: asyncio.Future[bool] | None = None

        # Directories where the files produced by this run are saved.
        name = name or datetime.now().strftime("%Y%m%d_%H%M%S")
        self._runs_dir = os.path.join(SCRIPT_DIR, "test-runs", name)
        self._logs_dir = os.path.join(self._runs_dir, "logs")
        self._recordings_dir = os.path.join(self._runs_dir, "recordings")
        os.makedirs(self._logs_dir, exist_ok=True)
        os.makedirs(self._recordings_dir, exist_ok=True)

    async def function_assert_eval(self, params: FunctionCallParams):
        """Handle the eval bot's `eval_function` tool call.

        Logs the reported result and reasoning, completes the function call
        with no result and pushes an `EndTaskFrame` upstream whose reason is
        the reported result.

        Args:
            params: Function call parameters; `arguments` must contain the
                `result` and `reasoning` keys.
        """
        result = params.arguments["result"]
        reasoning = params.arguments["reasoning"]
        logger.debug(f"🧠 EVAL REASONING(result: {result}): {reasoning}")
        await params.result_callback(None)
        await params.llm.push_frame(EndTaskFrame(reason=result), FrameDirection.UPSTREAM)

    async def assert_eval(self, result: bool):
        """Resolve the pending eval result, if there is one.

        The call is ignored when no eval is in progress or when the result
        future is already done (for example because waiting for it timed out
        and it was cancelled), so a late or duplicate result cannot raise
        `asyncio.InvalidStateError`.

        Args:
            result: Whether the eval succeeded.
        """
        future = self._result_future
        if future is None or future.done():
            return
        future.set_result(result)

    async def run_eval(
        self,
        example_file: str,
        eval_config: EvalConfig,
    ):
        """Run one example together with the eval bot and record the outcome.

        Nothing happens if `example_file` does not match the runner pattern.

        Args:
            example_file: Example file name, relative to the examples directory.
            eval_config: Configuration of the eval to run.
        """
        if not re.match(self._pattern, example_file):
            return

        # Store logs
        filename = self._log_file_name(example_file)
        log_file_id = logger.add(filename, level=self._log_level)

        # The log sink is removed in `finally` so it does not leak if anything
        # below propagates (e.g. a cancellation).
        try:
            print_begin_test(example_file)

            script_path = self._examples_dir / example_file

            start_time = time.time()

            # Create a future to store the eval result.
            self._result_future = asyncio.get_running_loop().create_future()

            try:
                tasks = [
                    asyncio.create_task(run_example_pipeline(script_path, eval_config)),
                    asyncio.create_task(run_eval_pipeline(self, example_file, eval_config)),
                ]
                done, pending = await asyncio.wait(tasks, timeout=EVAL_TIMEOUT_SECS)
                if pending:
                    logger.error("ERROR: Eval timeout expired, cancelling pending tasks...")
                    # Both pipeline idle timeouts should have worked and both tasks
                    # should have exited already, but if we got here something went
                    # wrong so we perform an abrupt asyncio task cancellation, which
                    # will not cleanup things nicely.
                    for task in pending:
                        task.cancel()
                    await asyncio.gather(*pending, return_exceptions=True)

                # `asyncio.wait` never raises the exceptions of the tasks it
                # waits for, so failures have to be collected explicitly or
                # they would be silently dropped.
                for task in done:
                    if task.cancelled():
                        continue
                    error = task.exception()
                    if error is not None:
                        logger.error(f"ERROR: Unable to run {example_file}: {error}")
            except Exception as e:
                logger.error(f"ERROR: Unable to run {example_file}: {e}")

            try:
                # Wait for the future to resolve.
                result = await asyncio.wait_for(
                    self._result_future, timeout=EVAL_RESULT_TIMEOUT_SECS
                )
            except asyncio.TimeoutError:
                # `asyncio.TimeoutError` is only an alias of the builtin
                # `TimeoutError` since Python 3.11; using the asyncio name
                # works on every version.
                logger.error("ERROR: Timeout waiting for eval result.")
                result = False

            if result:
                self._total_success += 1

            eval_time = time.time() - start_time

            self._tests.append(EvalResult(name=example_file, result=result, time=eval_time))

            print_end_test(example_file, result, eval_time)
        finally:
            logger.remove(log_file_id)

    def print_results(self):
        """Print the summary of all the evals that have been run so far."""
        print_test_results(self._tests, self._total_success, self._runs_dir)

    async def save_audio(self, name: str, audio: bytes, sample_rate: int, num_channels: int):
        """Save raw audio as a WAV file in the recordings directory.

        The audio is written with a sample width of 2 bytes. If `audio` is
        empty nothing is written and a warning is logged.

        Args:
            name: Name used to build the recording file name.
            audio: Raw audio frames.
            sample_rate: Sample rate written to the WAV header.
            num_channels: Number of channels written to the WAV header.
        """
        if not audio:
            logger.warning(f"There's no audio to save for {name}")
            return

        filename = self._recording_file_name(name)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        logger.debug(f"Saving {name} audio to {filename}")
        with io.BytesIO() as buffer:
            # Build the WAV container in memory first, then write it to disk
            # asynchronously in a single call.
            with wave.open(buffer, "wb") as wf:
                wf.setsampwidth(2)
                wf.setnchannels(num_channels)
                wf.setframerate(sample_rate)
                wf.writeframes(audio)
            async with aiofiles.open(filename, "wb") as file:
                await file.write(buffer.getvalue())

    def _base_file_name(self, example_file: str):
        """Return the example name without extension plus a timestamp suffix."""
        base_name = os.path.splitext(example_file)[0]
        return f"{base_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    def _log_file_name(self, example_file: str):
        """Return the path of the log file for `example_file`."""
        base_name = self._base_file_name(example_file)
        return os.path.join(self._logs_dir, f"{base_name}.log")

    def _recording_file_name(self, example_file: str):
        """Return the path of the WAV recording for `example_file`."""
        base_name = self._base_file_name(example_file)
        return os.path.join(self._recordings_dir, f"{base_name}.wav")
|
||
|
||
|
||
async def run_example_pipeline(script_path: Path, eval_config: EvalConfig):
    """Load an example script and run its bot in the Daily room.

    The room URL is read from the `DAILY_ROOM_URL` environment variable and
    the example joins it as "Pipecat". The loaded module must provide a
    `run_bot(transport, runner_args)` coroutine.

    Args:
        script_path: Path of the example script to load.
        eval_config: Eval configuration; its `runner_args_body` becomes the
            runner arguments body.

    Raises:
        KeyError: If `DAILY_ROOM_URL` is not set.
    """
    daily_room_url = os.environ["DAILY_ROOM_URL"]

    example = load_module_from_path(script_path)

    transport_params = DailyParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        video_in_enabled=True,
    )
    daily_transport = DailyTransport(daily_room_url, None, "Pipecat", transport_params)

    args = RunnerArguments()
    args.pipeline_idle_timeout_secs = PIPELINE_IDLE_TIMEOUT_SECS
    args.body = eval_config.runner_args_body

    await example.run_bot(daily_transport, args)
|
||
|
||
|
||
async def run_eval_pipeline(
    eval_runner: EvalRunner,
    example_file: str,
    eval_config: EvalConfig,
):
    """Run the eval bot that talks to the example and judges its answer.

    The eval bot joins the Daily room given by `DAILY_ROOM_URL` as
    "Pipecat Eval" and runs an STT -> LLM -> TTS pipeline. The LLM is given an
    `eval_function` tool; when it is called, `eval_runner` ends the task and
    the outcome is reported through `eval_runner.assert_eval` once the
    pipeline finishes. The eval bot's audio is handed to
    `eval_runner.save_audio`.

    Args:
        eval_runner: Runner that receives the eval result and the audio.
        example_file: Name of the example being evaluated; used to name the
            recording.
        eval_config: Prompt, evaluation criteria and conversation order.

    Raises:
        KeyError: If `DAILY_ROOM_URL`, `DEEPGRAM_API_KEY` or
            `CARTESIA_API_KEY` is not set.
    """
    logger.info("Starting eval bot")

    room_url = os.environ["DAILY_ROOM_URL"]

    transport = DailyTransport(
        room_url,
        None,
        "Pipecat Eval",
        DailyParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            video_out_enabled=True,
        ),
    )

    # We disable smart formatting because some times if the user says "3 + 2 is
    # 5" (in audio) this can be converted to "32 is 5".
    stt = DeepgramSTTService(
        api_key=os.environ["DEEPGRAM_API_KEY"],
        settings=DeepgramSTTService.Settings(
            language="multi",
            smart_format=False,
        ),
    )

    tts = CartesiaTTSService(
        api_key=os.environ["CARTESIA_API_KEY"],
        settings=CartesiaTTSService.Settings(
            voice="97f4b8fb-f2fe-444b-bb9a-c109783a857a",  # Nathan
        ),
    )

    eval_function = FunctionSchema(
        name="eval_function",
        description=(
            "Determines whether the user's response satisfies the evaluation "
            "criteria defined for the current prompt or interaction."
        ),
        properties={
            "result": {
                "type": "boolean",
                "description": "Whether the user's response meets the evaluation criteria.",
            },
            "reasoning": {
                "type": "string",
                "description": (
                    "A concise explanation of how the user's response did or did "
                    "not satisfy the evaluation criteria."
                ),
            },
        },
        required=["result", "reasoning"],
    )
    tools = ToolsSchema(standard_tools=[eval_function])

    # Split the prompt into its text and, if present, its image.
    example_prompt = ""
    example_image: ImageFile | None = None
    if isinstance(eval_config.prompt, str):
        example_prompt = eval_config.prompt
    elif isinstance(eval_config.prompt, tuple):
        example_prompt, example_image = eval_config.prompt

    common_system_prompt = (
        "You should only call the eval function if:\n"
        "- The user explicitly attempts to answer the question, AND\n"
        f"- Their answer can be cleanly evaluated using: {eval_config.eval}\n"
        "Ignore greetings, comments, non-answers, or requests for clarification.\n"
        "Numerical word answers are allowed (e.g., 'five' is the same as '5').\n"
    )
    if eval_config.eval_speaks_first:
        system_prompt = (
            "You are an evaluation agent, be extremly brief. "
            f"You will start the conversation by saying: '{example_prompt}'. "
            f"{common_system_prompt}"
        )
    else:
        system_prompt = (
            "You are an evaluation agent, be extremly brief. "
            f"First, ask one question: {example_prompt}. "
            f"{common_system_prompt}"
        )

    llm = OpenAILLMService(
        api_key=os.getenv("OPENAI_API_KEY"),
        settings=OpenAILLMService.Settings(
            system_instruction=system_prompt,
        ),
    )

    llm.register_function("eval_function", eval_runner.function_assert_eval)

    context = LLMContext(tools=tools)
    context_aggregator = LLMContextAggregatorPair(
        context,
        user_params=LLMUserAggregatorParams(
            vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=2.0)),
        ),
    )

    audio_buffer = AudioBufferProcessor()

    pipeline = Pipeline(
        [
            transport.input(),  # Transport user input
            stt,  # STT
            context_aggregator.user(),  # User responses
            llm,  # LLM
            tts,  # TTS
            transport.output(),  # Transport bot output
            audio_buffer,
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            audio_in_sample_rate=16000,
            audio_out_sample_rate=16000,
        ),
        enable_rtvi=False,
        idle_timeout_secs=PIPELINE_IDLE_TIMEOUT_SECS,
    )

    @audio_buffer.event_handler("on_audio_data")
    async def on_audio_data(buffer, audio, sample_rate, num_channels):
        await eval_runner.save_audio(example_file, audio, sample_rate, num_channels)

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        if example_image:
            # NOTE(review): the raw bytes are sent as "RGB"; this presumably
            # expects the image to already be in RGB mode — confirm.
            await task.queue_frame(
                OutputImageRawFrame(
                    image=example_image.tobytes(),
                    size=example_image.size,
                    format="RGB",
                )
            )
        await audio_buffer.start_recording()

        # Default behavior is for the bot to speak first
        # If the eval bot speaks first, we append the prompt to the messages
        if eval_config.eval_speaks_first:
            # Use the prompt text only: `eval_config.prompt` may be a
            # (text, image) tuple, whose repr must not end up in the message.
            context.add_message(
                {"role": "user", "content": f"Start by saying this exactly: '{example_prompt}'"}
            )
            await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    @task.event_handler("on_pipeline_finished")
    async def on_pipeline_finished(task, frame):
        # An EndFrame carries the eval result as its reason (set when the
        # eval function pushes EndTaskFrame); a cancellation counts as failure.
        if isinstance(frame, EndFrame):
            await eval_runner.assert_eval(bool(frame.reason))
        elif isinstance(frame, CancelFrame):
            await eval_runner.assert_eval(False)

    # TODO(aleix): We should handle SIGINT and SIGTERM so we can cancel both the
    # eval and the example.
    runner = PipelineRunner(handle_sigint=False)

    await runner.run(task)
|