Compare commits
1 Commits
hush/prere
...
fix/fastap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae8b9f0756 |
@@ -5,7 +5,6 @@
|
||||
#
|
||||
|
||||
import os
|
||||
import wave
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
@@ -14,14 +13,7 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
LLMTextFrame,
|
||||
OutputAudioRawFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -111,27 +103,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
|
||||
audio_file_path = os.path.join(os.path.dirname(__file__), "assets", "pre-recorded.wav")
|
||||
|
||||
with wave.open(audio_file_path, "rb") as wav_file:
|
||||
llm_text_frame = TextFrame(text="This is a pre-recorded message.")
|
||||
llm_text_frame.skip_tts = True
|
||||
|
||||
audio_data = wav_file.readframes(wav_file.getnframes())
|
||||
output_audio_raw_frame = OutputAudioRawFrame(
|
||||
audio=audio_data, sample_rate=44100, num_channels=1
|
||||
)
|
||||
|
||||
await task.queue_frames(
|
||||
[
|
||||
LLMRunFrame(),
|
||||
LLMFullResponseStartFrame(),
|
||||
llm_text_frame,
|
||||
output_audio_raw_frame,
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
|
||||
Binary file not shown.
@@ -278,6 +278,13 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Main message receiving loop for WebSocket messages."""
|
||||
|
||||
async def trigger_disconnect_if_needed():
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._client.is_closing:
|
||||
await self._client.trigger_client_disconnected()
|
||||
|
||||
try:
|
||||
async for message in self._client.receive():
|
||||
if not self._params.serializer:
|
||||
@@ -294,11 +301,14 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
# Trigger `on_client_disconnected` if the client actually disconnects,
|
||||
# that is, we are not the ones disconnecting.
|
||||
if not self._client.is_closing:
|
||||
await self._client.trigger_client_disconnected()
|
||||
finally:
|
||||
# Use shield to prevent cancellation from stopping the disconnect callback
|
||||
try:
|
||||
await asyncio.shield(trigger_disconnect_if_needed())
|
||||
except asyncio.CancelledError:
|
||||
# Even if we're cancelled, try to trigger the disconnect
|
||||
await trigger_disconnect_if_needed()
|
||||
raise
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
|
||||
|
||||
Reference in New Issue
Block a user