Compare commits

...

7 Commits

Author SHA1 Message Date
James Hush
1884ff3f09 logging 2024-11-27 19:38:37 +08:00
James Hush
f34e6bce94 Switch questions 2024-11-27 15:10:50 +08:00
James Hush
909bb30517 Better recreation 2024-11-27 14:08:01 +08:00
James Hush
632bae7eee Interrupted? 2024-11-27 12:21:45 +08:00
James Hush
cedccdcbc0 Add interruptions 2024-11-27 11:50:28 +08:00
James Hush
1893784b89 Save race bot 2024-11-27 11:36:28 +08:00
James Hush
e2384e2484 fix: add logging and error handling for issue #721 2024-11-26 11:22:58 +08:00
5 changed files with 221 additions and 1 deletions

View File

@@ -10,11 +10,12 @@ import os
import sys
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.frames.frames import BotSpeakingFrame, Frame, InputAudioRawFrame, LLMMessagesFrame, TTSAudioRawFrame, TextFrame, UserStoppedSpeakingFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -30,6 +31,22 @@ load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class DebugProcessor(FrameProcessor):
    """Pass-through processor that logs the frames travelling through it.

    Every frame is forwarded unchanged in the direction it arrived. Frames of
    the types listed in ``process_frame`` are forwarded without being logged
    (presumably because they are too frequent to be useful in a debug log).

    Args:
        name: Label prepended to each log line so several instances can be
            told apart.
        **kwargs: Forwarded untouched to ``FrameProcessor.__init__``.
    """

    def __init__(self, name, **kwargs):
        self._name = name
        super().__init__(**kwargs)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Log ``frame`` (unless its type is skipped) and push it onward."""
        await super().process_frame(frame, direction)

        # Frame types that are forwarded silently.
        unlogged_types = (
            InputAudioRawFrame,
            BotSpeakingFrame,
            TTSAudioRawFrame,
            TextFrame,
        )
        if not isinstance(frame, unlogged_types):
            logger.debug(f"--- {self._name}: {frame} {direction}")

        await self.push_frame(frame, direction)
async def main():
async with aiohttp.ClientSession() as session:
@@ -63,11 +80,14 @@ async def main():
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
dp = DebugProcessor("dp")
pipeline = Pipeline(
[
transport.input(), # Transport user input
context_aggregator.user(), # User responses
dp,
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output

View File

@@ -0,0 +1,191 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import time
import aiohttp
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
BotSpeakingFrame,
EndFrame,
Frame,
InputAudioRawFrame,
StartInterruptionFrame,
StopInterruptionFrame,
TextFrame,
TranscriptionFrame,
TTSAudioRawFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
# Drop the handler registered under id 0 and log to stderr at DEBUG level
# instead, so the "---" frame traces emitted via logger.debug are visible.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class DebugProcessor(FrameProcessor):
    """Pass-through processor that logs the frames travelling through it.

    Every frame is forwarded unchanged in the direction it arrived. Frames of
    the types listed in ``process_frame`` are forwarded without being logged
    (presumably because they are too frequent to be useful in a debug log).

    Args:
        name: Label prepended to each log line so several instances can be
            told apart.
        **kwargs: Forwarded untouched to ``FrameProcessor.__init__``.
    """

    def __init__(self, name, **kwargs):
        self._name = name
        super().__init__(**kwargs)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Log ``frame`` (unless its type is skipped) and push it onward."""
        await super().process_frame(frame, direction)

        # Frame types that are forwarded silently.
        unlogged_types = (
            InputAudioRawFrame,
            BotSpeakingFrame,
            UserStoppedSpeakingFrame,
            TTSAudioRawFrame,
            TextFrame,
        )
        if not isinstance(frame, unlogged_types):
            logger.debug(f"--- {self._name}: {frame} {direction}")

        await self.push_frame(frame, direction)
async def main():
    """Run a Daily bot whose user turns are injected synthetically.

    Builds a pipeline of user context aggregator -> LLM -> DebugProcessor ->
    TTS -> transport output -> assistant context aggregator, with
    interruptions enabled. Once the first participant joins, scripted user
    turns (interruption start, speaking start, transcription, interruption
    stop, speaking stop) are queued on the task in a loop, and an EndFrame is
    queued when the loop finishes.

    Reads CARTESIA_API_KEY and OPENAI_API_KEY from the environment; a missing
    OPENAI_API_KEY raises KeyError.
    """
    async with aiohttp.ClientSession() as session:
        (room_url, _) = await configure(session)

        transport = DailyTransport(
            room_url,
            None,
            "AI Bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
            ),
        )

        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        )

        llm = OpenAILLMService(api_key=os.environ["OPENAI_API_KEY"], model="gpt-4o")

        messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
            },
        ]

        dp = DebugProcessor("dp")

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        runner = PipelineRunner()

        task = PipelineTask(
            Pipeline(
                [
                    # transport.input() is left out: the user frames are
                    # queued directly on the task by the handler below.
                    # transport.input(),
                    context_aggregator.user(),
                    llm,
                    dp,
                    tts,
                    transport.output(),
                    context_aggregator.assistant(),
                ]
            ),
            PipelineParams(
                allow_interruptions=True,
            ),
        )

        # Register an event handler so the scripted conversation starts when
        # the first participant joins.
        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            participant_id = participant.get("info", {}).get("participantId", "")

            async def simulate_user_turn(text):
                """Queue one scripted user turn that says ``text``."""
                await task.queue_frame(StartInterruptionFrame())
                await asyncio.sleep(1)
                await task.queue_frame(UserStartedSpeakingFrame())
                await asyncio.sleep(1)
                await task.queue_frame(TranscriptionFrame(text, participant_id, time.time()))
                await asyncio.sleep(1)
                # Single frame, so queue_frame (not queue_frames, which takes
                # a collection of frames).
                await task.queue_frame(StopInterruptionFrame())
                await asyncio.sleep(1)
                await task.queue_frame(UserStoppedSpeakingFrame())
                # Pause before the next turn starts.
                await asyncio.sleep(5)

            # Keep starting new rounds of scripted turns for 300 seconds. The
            # deadline is only checked between rounds, so the last round runs
            # to completion.
            start_time = time.time()
            while time.time() - start_time < 300:
                elapsed_time = round(time.time() - start_time)
                logger.info(f"Running for {elapsed_time} seconds")

                await simulate_user_turn("Tell me more about your company.")
                await simulate_user_turn("Give me a list of appointment dates.")

            await task.queue_frame(EndFrame())

        # @transport.event_handler("on_first_participant_joined")
        # async def on_first_participant_joined(transport, participant):
        #     await transport.capture_participant_transcription(participant["id"])
        #     # Kick off the conversation.
        #     messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        #     await task.queue_frames([LLMMessagesFrame(messages)])

        await runner.run(task)
# Script entry point: run the async main() when executed directly.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -7,6 +7,7 @@
import asyncio
import base64
import json
import random
import uuid
from typing import AsyncGenerator, List, Optional, Union
@@ -222,6 +223,10 @@ class CartesiaTTSService(WordTTSService):
async def _receive_task_handler(self):
try:
async for message in self._get_websocket():
# Randomly cancel the asyncio task 1% of the time
if random.random() < 0.01:
logger.info(f"Cancelling task for {self} due to random chance")
asyncio.current_task().cancel()
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
@@ -256,6 +261,7 @@ class CartesiaTTSService(WordTTSService):
logger.error(f"Cartesia error, unknown message type: {msg}")
except asyncio.CancelledError:
pass
# await self.push_error(ErrorFrame(f"{self} cancelled", True))
except Exception as e:
logger.error(f"{self} exception: {e}")

View File

@@ -71,6 +71,7 @@ class BaseInputTransport(FrameProcessor):
return self._params.vad_analyzer
async def push_audio_frame(self, frame: InputAudioRawFrame):
logger.info(f"Pushing audio qsize: {self._audio_in_queue.qsize()}")
if self._params.audio_in_enabled or self._params.vad_enabled:
await self._audio_in_queue.put(frame)
@@ -167,6 +168,7 @@ class BaseInputTransport(FrameProcessor):
return vad_state
async def _audio_task_handler(self):
logger.info("_audio_task_handler started")
vad_state: VADState = VADState.QUIET
while True:
try:

View File

@@ -106,6 +106,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
continue
if isinstance(frame, AudioRawFrame):
logger.info("websocket_server")
await self.push_audio_frame(
InputAudioRawFrame(
audio=frame.audio,