Compare commits

..

1 Commits

Author SHA1 Message Date
James Hush
181cc43724 Background blur example 2025-04-08 14:09:19 +08:00
16 changed files with 88 additions and 491 deletions

View File

@@ -9,14 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added support in `SmallWebRTCTransport` to detect when remote tracks are
muted.
- Added support for image capture from a video stream to the
`SmallWebRTCTransport`.
- Added a new iOS client option to the `SmallWebRTCTransport`
**video-transform** example.
- Added a new iOS client option to the `SmallWebRTCTransport` **video-transform** example.
- Added new processors `ProducerProcessor` and `ConsumerProcessor`. The
producer processor processes frames from the pipeline and decides whether the
@@ -32,17 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
type was incorrectly handled as a codec retransmission.
- Avoid initial video delays.
### Changed
- Updated `GeminiMultimodalLiveLLMService`s default `model` to
`models/gemini-2.0-flash-live-001` and `base_url` to the `v1beta` websocket
URL.
### Fixed
- Updated `daily-python` to 0.17.0 to fix an issue that was preventing to run on
older platforms.
- Fixed an issue in the Azure TTS services where the language was being set
incorrectly.

View File

@@ -1,110 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import aiohttp
import asyncio
import os
import sys
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.services.mcp_run.mcp_run import MCPRun
from pipecat.services.anthropic.llm import AnthropicLLMService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.openai.llm import OpenAILLMService
load_dotenv(override=True)
logger.remove()
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Bot with MCP tools",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
)
llm = AnthropicLLMService(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-7-sonnet-latest")
# llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"), model="gemini-2.0-flash-001")
# llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
mcp_run = MCPRun(llm)
tools = mcp_run.register_mcp_tools(llm)
system = """
You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities
in a succinct way. You have access to various tools provided by mcp.run that you can use to help users.
Your output will be converted to audio so don't include special characters in your answers.
Respond to what the user said in a creative and helpful way. Don't overexplain what you are doing.
Just respond with short sentences when you are carrying out tool calls.
"""
messages = [{"role": "system","content": system}]
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
context_aggregator.user(), # User spoken responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses and tool context
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
),
)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
logger.info("First participant joined: {}", participant["id"])
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
await task.queue_frames([context_aggregator.user().get_context_frame()])
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -35,19 +35,9 @@ cd server
python server.py
```
### 2Test with SmallWebRTC Prebuilt UI
### 2Connect Using the Client App
You can quickly test your bot using the `SmallWebRTCPrebuiltUI`:
- Open your browser and navigate to:
👉 http://localhost:7860
- (Or use your custom port, if configured)
### 3⃣ Connect Using a Custom Client App
For client-side setup, refer to the:
- [Typescript Guide](client/typescript/README.md).
- [iOS Guide](client/ios/README.md).
For client-side setup, refer to the [JavaScript Guide](client/typescript/README.md).
## ⚠️ Important Note
Ensure the bot server is running before using any client implementations.

View File

@@ -51,7 +51,6 @@
<div class="bot-container">
<div id="bot-video-container">
<video id="bot-video" autoplay="true" playsinline="true"></video>
<button id="mute-btn">📷</button>
</div>
<audio id="bot-audio" autoplay></audio>
</div>

View File

@@ -10,7 +10,7 @@
"license": "ISC",
"dependencies": {
"@pipecat-ai/client-js": "^0.3.2",
"@pipecat-ai/small-webrtc-transport": "^0.0.2"
"@pipecat-ai/small-webrtc-transport": "^0.0.1"
},
"devDependencies": {
"@types/node": "^22.13.1",
@@ -32,9 +32,9 @@
}
},
"node_modules/@daily-co/daily-js": {
"version": "0.77.0",
"resolved": "https://registry.npmjs.org/@daily-co/daily-js/-/daily-js-0.77.0.tgz",
"integrity": "sha512-icNXKieKAkRR/C5dcPjrCkL1jQGFp5C5WtLHy5uHAdTztm+mo9wlPJuehbWaGOM3TV24mgWHZ/+8jOys1G0I4w==",
"version": "0.73.0",
"resolved": "https://registry.npmjs.org/@daily-co/daily-js/-/daily-js-0.73.0.tgz",
"integrity": "sha512-Wz8c60hgmkx8fcEeDAi4L4J0rbafiihWKyXFyhYoFYPsw2OdChHpA4RYwIB+1enRws5IK+/HdmzFDYLQsB4A6w==",
"license": "BSD-2-Clause",
"dependencies": {
"@babel/runtime": "^7.12.5",
@@ -78,12 +78,12 @@
}
},
"node_modules/@pipecat-ai/small-webrtc-transport": {
"version": "0.0.2",
"resolved": "https://registry.npmjs.org/@pipecat-ai/small-webrtc-transport/-/small-webrtc-transport-0.0.2.tgz",
"integrity": "sha512-9QQBjfAY0yh+ehDt6jX+bX7Ar5GFl+iI6QFS+JPRXeDYCj70bqmUgCYkScbgWzb5uRWZ8ORM+ueVkaLibe+Y4Q==",
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/@pipecat-ai/small-webrtc-transport/-/small-webrtc-transport-0.0.1.tgz",
"integrity": "sha512-WAOI7lT0V7cYOn0+qwUAryGxcOGe+wPVPEPzkR3qsM5GWIZ73spykZnuOndQGycq4UkcXVawCzERfNhpi+Uv7A==",
"license": "BSD-2-Clause",
"dependencies": {
"@daily-co/daily-js": "^0.77.0",
"@daily-co/daily-js": "^0.73.0",
"dequal": "^2.0.3"
},
"peerDependencies": {

View File

@@ -19,6 +19,6 @@
},
"dependencies": {
"@pipecat-ai/client-js": "^0.3.2",
"@pipecat-ai/small-webrtc-transport": "^0.0.2"
"@pipecat-ai/small-webrtc-transport": "^0.0.1"
}
}

View File

@@ -1,13 +1,12 @@
import {
SmallWebRTCTransport
} from "@pipecat-ai/small-webrtc-transport";
import {Participant, RTVIClient, RTVIClientOptions, Transport} from "@pipecat-ai/client-js";
import {Participant, RTVIClient, RTVIClientOptions} from "@pipecat-ai/client-js";
class WebRTCApp {
private declare connectBtn: HTMLButtonElement;
private declare disconnectBtn: HTMLButtonElement;
private declare muteBtn: HTMLButtonElement;
private declare audioInput: HTMLSelectElement;
private declare videoInput: HTMLSelectElement;
@@ -33,10 +32,12 @@ class WebRTCApp {
private initializeRTVIClient(): void {
const transport = new SmallWebRTCTransport();
const RTVIConfig: RTVIClientOptions = {
// need to understand why it is complaining
// @ts-ignore
transport,
params: {
baseUrl: "/api/offer"
},
transport: transport as Transport,
enableMic: true,
enableCam: true,
callbacks: {
@@ -91,7 +92,6 @@ class WebRTCApp {
private setupDOMElements(): void {
this.connectBtn = document.getElementById('connect-btn') as HTMLButtonElement;
this.disconnectBtn = document.getElementById('disconnect-btn') as HTMLButtonElement;
this.muteBtn = document.getElementById('mute-btn') as HTMLButtonElement;
this.audioInput = document.getElementById('audio-input') as HTMLSelectElement;
this.videoInput = document.getElementById('video-input') as HTMLSelectElement;
@@ -118,12 +118,6 @@ class WebRTCApp {
let videoDevice = e.target?.value
this.rtviClient.updateCam(videoDevice)
})
this.muteBtn.addEventListener('click', () => {
let isCamEnabled = this.rtviClient.isCamEnabled
this.rtviClient.enableCam(!isCamEnabled)
this.muteBtn.textContent = isCamEnabled ? '📵' : '📷';
});
}
private log(message: string): void {

View File

@@ -89,7 +89,6 @@ button:disabled {
display: flex;
align-items: center;
justify-content: center;
position: relative;
}
#bot-video-container video {
@@ -98,20 +97,6 @@ button:disabled {
object-fit: cover;
}
#mute-btn {
position: absolute;
bottom: 10px;
right: 10px;
background-color: rgba(0, 0, 0, 0.6);
color: white;
border: none;
border-radius: 20px;
padding: 8px 12px;
cursor: pointer;
font-size: 16px;
z-index: 1;
}
.debug-panel {
background-color: #fff;
border-radius: 8px;

View File

@@ -3,5 +3,4 @@ fastapi[all]
uvicorn
aiortc
opencv-python
pipecat-ai[google,silero,webrtc]
pipecat-ai-small-webrtc-prebuilt
pipecat-ai[google,silero]

View File

@@ -8,8 +8,6 @@ import uvicorn
from bot import run_bot
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import RedirectResponse
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
@@ -25,14 +23,6 @@ pcs_map: Dict[str, SmallWebRTCConnection] = {}
ice_servers = ["stun:stun.l.google.com:19302"]
# Mount the frontend at /
app.mount("/prebuilt", SmallWebRTCPrebuiltUI)
@app.get("/", include_in_schema=False)
async def root_redirect():
return RedirectResponse(url="/prebuilt/")
@app.post("/api/offer")
async def offer(request: dict, background_tasks: BackgroundTasks):

View File

@@ -1,16 +1,29 @@
import { type PropsWithChildren } from 'react';
import { RTVIClient } from '@pipecat-ai/client-js';
import { DailyTransport } from '@pipecat-ai/daily-transport';
import { RTVIClientProvider } from '@pipecat-ai/client-react';
import { type PropsWithChildren } from "react";
import { RTVIClient } from "@pipecat-ai/client-js";
import { DailyTransport } from "@pipecat-ai/daily-transport";
import { RTVIClientProvider } from "@pipecat-ai/client-react";
const transport = new DailyTransport();
const transport = new DailyTransport({
dailyFactoryOptions: {
inputSettings: {
video: {
processor: {
type: "background-blur",
config: {
strength: 0.8,
},
},
},
},
},
});
const client = new RTVIClient({
transport,
params: {
baseUrl: 'http://localhost:7860',
baseUrl: "http://localhost:7860",
endpoints: {
connect: '/connect',
connect: "/connect",
},
},
enableMic: true,

View File

@@ -47,7 +47,7 @@ canonical = [ "aiofiles~=24.1.0" ]
cartesia = [ "cartesia~=1.4.0", "websockets~=13.1" ]
cerebras = []
deepseek = []
daily = [ "daily-python~=0.17.0" ]
daily = [ "daily-python~=0.16.1" ]
deepgram = [ "deepgram-sdk~=3.8.0" ]
elevenlabs = [ "websockets~=13.1" ]
fal = [ "fal-client~=0.5.9" ]

View File

@@ -164,8 +164,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
self,
*,
api_key: str,
base_url: str = "generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent",
model="models/gemini-2.0-flash-live-001",
base_url="generativelanguage.googleapis.com",
model="models/gemini-2.0-flash-exp",
voice_id: str = "Charon",
start_audio_paused: bool = False,
start_video_paused: bool = False,
@@ -179,8 +179,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
):
super().__init__(base_url=base_url, **kwargs)
self._last_sent_time = 0
self._api_key = api_key
self._base_url = base_url
self.api_key = api_key
self.base_url = base_url
self.set_model_name(model)
self._voice_id = voice_id
@@ -407,8 +407,8 @@ class GeminiMultimodalLiveLLMService(LLMService):
logger.info("Connecting to Gemini service")
try:
logger.info(f"Connecting to wss://{self._base_url}")
uri = f"wss://{self._base_url}?key={self._api_key}"
uri = f"wss://{self.base_url}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
logger.info(f"Connecting to {uri}")
self._websocket = await websockets.connect(uri=uri)
self._receive_task = self.create_task(self._receive_task_handler())
self._transcribe_audio_task = self.create_task(self._transcribe_audio_handler())

View File

@@ -1,141 +0,0 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
import json
from typing import Any, Dict, List, Mapping, Optional, Union
from loguru import logger
from pipecat.services.llm_service import LLMService
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from mcp_run import Client
try:
from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use mcp.run, you need to `pip install pipecat-ai[mcp_run]`. "
+ "Also, set `MCP_RUN_SESSION_ID` environment variable."
)
raise Exception(f"Missing module: {e}")
class MCPRun(Client):
def __init__(
self,
llm: LLMService,
mcp_run_session_id: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
self._client = Client()
self._mcp_run_session_id = mcp_run_session_id or os.getenv("MCP_RUN_SESSION_ID")
def convert_mcp_schema_to_pipecat(self, tool_name: str, tool_schema: dict[str, any]) -> FunctionSchema:
"""Convert an mcp.run tool schema to Pipecat's FunctionSchema format.
Args:
tool_name: The name of the tool
tool_schema: The mcp.run tool schema
Returns:
A FunctionSchema instance
"""
logger.debug(f"Converting schema for tool '{tool_name}'")
logger.debug(f"Original schema: {json.dumps(tool_schema, indent=2)}")
# Extract properties and required fields from the mcp.run schema
properties = tool_schema["input_schema"].get("properties", {})
required = tool_schema["input_schema"].get("required", [])
schema = FunctionSchema(
name=tool_name,
description=tool_schema["description"],
properties=properties,
required=required
)
logger.debug(f"Converted schema: {json.dumps(schema.to_default_dict(), indent=2)}")
return schema
def register_mcp_tools(self, llm) -> ToolsSchema:
"""Register all available mcp.run tools with the LLM service.
Args:
llm: The Pipecat LLM service to register tools with
Returns:
A ToolsSchema containing all registered tools
"""
async def mcp_tool_wrapper(function_name: str, tool_call_id: str, arguments: dict[str, any],
llm: any, context: any, result_callback: any) -> None:
"""Wrapper for mcp.run tool calls to match Pipecat's function call interface.
"""
logger.debug(f"Executing tool '{function_name}' with call ID: {tool_call_id}")
logger.debug(f"Tool arguments: {json.dumps(arguments, indent=2)}")
try:
# Call the mcp.run tool
logger.debug(f"Calling mcp.run tool '{function_name}'")
results = self._client.call_tool(function_name, params=arguments)
# Combine all content into a single response
response = ""
for i, content in enumerate(results.content):
logger.debug(f"Tool response chunk {i}: {content.text}")
response += content.text
logger.info(f"Tool '{function_name}' completed successfully")
logger.info(f"Final response: {response}")
# Send result back through callback
await result_callback(response)
except Exception as e:
error_msg = f"Error calling mcp.run tool {function_name}: {str(e)}"
logger.error(error_msg)
logger.exception("Full exception details:")
await result_callback(error_msg)
logger.debug("Starting registration of mcp.run tools")
tool_schemas: List[FunctionSchema] = []
# Get all available tools from mcp.run
available_tools = self._client.tools
logger.debug(f"Found {len(available_tools)} available tools")
for tool_name, tool in available_tools.items():
logger.debug(f"Processing tool: {tool_name}")
logger.debug(f"Tool description: {tool.description}")
try:
# Convert the schema
function_schema = self.convert_mcp_schema_to_pipecat(tool_name, {
"description": tool.description,
"input_schema": tool.input_schema
})
# Register the wrapped function
logger.debug(f"Registering function handler for '{tool_name}'")
llm.register_function(tool_name, mcp_tool_wrapper)
# Add to our list of schemas
tool_schemas.append(function_schema)
logger.debug(f"Successfully registered tool '{tool_name}'")
except Exception as e:
logger.error(f"Failed to register tool '{tool_name}': {str(e)}")
logger.exception("Full exception details:")
continue
logger.info(f"Completed registration of {len(tool_schemas)} tools")
tools_schema = ToolsSchema(standard_tools=tool_schemas)
return tools_schema

View File

@@ -17,17 +17,13 @@ from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InputAudioRawFrame,
InputImageRawFrame,
OutputImageRawFrame,
StartFrame,
TransportMessageFrame,
TransportMessageUrgentFrame,
UserImageRawFrame,
UserImageRequestFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -63,7 +59,9 @@ class RawAudioTrack(AudioStreamTrack):
self._chunk_queue = deque()
def add_audio_bytes(self, audio_bytes: bytes):
"""Adds bytes to the audio buffer and returns a Future that completes when the data is processed."""
"""
Adds bytes to the audio buffer and returns a Future that completes when the data is processed.
"""
if len(audio_bytes) % self._bytes_per_10ms != 0:
raise ValueError("Audio bytes must be a multiple of 10ms size.")
future = asyncio.get_running_loop().create_future()
@@ -78,7 +76,9 @@ class RawAudioTrack(AudioStreamTrack):
return future
async def recv(self):
"""Returns the next audio frame, generating silence if needed."""
"""
Returns the next audio frame, generating silence if needed.
"""
# Compute required wait time for synchronization
if self._timestamp > 0:
wait = self._start + (self._timestamp / self._sample_rate) - time.time()
@@ -179,7 +179,8 @@ class SmallWebRTCClient:
await self._handle_app_message(message)
def _convert_frame(self, frame_array: np.ndarray, format_name: str) -> np.ndarray:
"""Convert a given frame to RGB format based on the input format.
"""
Convert a given frame to RGB format based on the input format.
Args:
frame_array (np.ndarray): The input frame.
@@ -202,7 +203,8 @@ class SmallWebRTCClient:
return cv2.cvtColor(frame_array, conversion_code)
async def read_video_frame(self):
"""Reads a video frame from the given MediaStreamTrack, converts it to RGB,
"""
Reads a video frame from the given MediaStreamTrack, converts it to RGB,
and creates an InputImageRawFrame.
"""
while True:
@@ -240,7 +242,9 @@ class SmallWebRTCClient:
yield image_frame
async def read_audio_frame(self):
"""Reads 20ms of audio from the given MediaStreamTrack and creates an InputAudioRawFrame."""
"""
Reads 20ms of audio from the given MediaStreamTrack and creates an InputAudioRawFrame.
"""
while True:
if self._audio_input_track is None:
await asyncio.sleep(0.01)
@@ -375,13 +379,6 @@ class SmallWebRTCInputTransport(BaseInputTransport):
self._params = params
self._receive_audio_task = None
self._receive_video_task = None
self._image_requests = {}
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, UserImageRequestFrame):
await self.request_participant_image(frame)
async def start(self, frame: StartFrame):
await super().start(frame)
@@ -427,22 +424,6 @@ class SmallWebRTCInputTransport(BaseInputTransport):
if video_frame:
await self.push_frame(video_frame)
# Check if there are any pending image requests and create UserImageRawFrame
if self._image_requests:
for req_id, request_frame in list(self._image_requests.items()):
# Create UserImageRawFrame using the current video frame
image_frame = UserImageRawFrame(
user_id=request_frame.user_id,
request=request_frame,
image=video_frame.image,
size=video_frame.size,
format=video_frame.format,
)
# Push the frame to the pipeline
await self.push_frame(image_frame)
# Remove from pending requests
del self._image_requests[req_id]
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
@@ -451,24 +432,6 @@ class SmallWebRTCInputTransport(BaseInputTransport):
frame = TransportMessageUrgentFrame(message=message)
await self.push_frame(frame)
# Add this method similar to DailyInputTransport.request_participant_image
async def request_participant_image(self, frame: UserImageRequestFrame):
"""Requests an image frame from the participant's video stream.
When a UserImageRequestFrame is received, this method will store the request
and the next video frame received will be converted to a UserImageRawFrame.
"""
logger.debug(f"Requesting image from participant: {frame.user_id}")
# Store the request
request_id = f"{frame.function_name}:{frame.tool_call_id}"
self._image_requests[request_id] = frame
# If we're not already receiving video, try to get a frame now
if not self._receive_video_task and self._params.camera_in_enabled:
# Start video reception if it's not already running
self._receive_video_task = self.create_task(self._receive_video())
class SmallWebRTCOutputTransport(BaseOutputTransport):
def __init__(

View File

@@ -7,22 +7,15 @@
import asyncio
import json
import time
from typing import Any, Literal, Optional, Union
from enum import Enum
from typing import Any, Optional
from av.frame import Frame
from loguru import logger
from pydantic import BaseModel, TypeAdapter
from pipecat.utils.base_object import BaseObject
try:
from aiortc import (
MediaStreamTrack,
RTCConfiguration,
RTCIceServer,
RTCPeerConnection,
RTCSessionDescription,
)
from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription
from aiortc.rtcrtpreceiver import RemoteStreamTrack
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
@@ -30,57 +23,10 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
SIGNALLING_TYPE = "signalling"
AUDIO_TRANSCEIVER_INDEX = 0
VIDEO_TRANSCEIVER_INDEX = 1
class TrackStatusMessage(BaseModel):
type: Literal["trackStatus"]
receiver_index: int
enabled: bool
class RenegotiateMessage(BaseModel):
type: Literal["renegotiate"] = "renegotiate"
class SignallingMessage:
Inbound = Union[TrackStatusMessage] # in case we need to add new messages in the future
outbound = Union[RenegotiateMessage]
class SmallWebRTCTrack:
def __init__(self, track: MediaStreamTrack):
self._track = track
self._enabled = True
def set_enabled(self, enabled: bool) -> None:
self._enabled = enabled
def is_enabled(self) -> bool:
return self._enabled
async def discard_old_frames(self):
remote_track = self._track
if isinstance(remote_track, RemoteStreamTrack):
if not hasattr(remote_track, "_queue") or not isinstance(
remote_track._queue, asyncio.Queue
):
print("Warning: _queue does not exist or has changed in aiortc.")
return
logger.debug("Discarding old frames")
while not remote_track._queue.empty():
remote_track._queue.get_nowait() # Remove the oldest frame
remote_track._queue.task_done()
async def recv(self) -> Optional[Frame]:
if not self._enabled:
return None
return await self._track.recv()
def __getattr__(self, name):
# Forward other attribute/method calls to the underlying track
return getattr(self._track, name)
class SignallingMessage(Enum):
RENEGOTIATE = "renegotiate"
class SmallWebRTCConnection(BaseObject):
@@ -91,12 +37,6 @@ class SmallWebRTCConnection(BaseObject):
else:
self.ice_servers = []
self._connect_invoked = False
self._track_map = {}
self._track_getters = {
AUDIO_TRANSCEIVER_INDEX: self.audio_input_track,
VIDEO_TRANSCEIVER_INDEX: self.video_input_track,
}
self._initialize()
# Register supported handlers. The user will only be able to register
@@ -128,6 +68,7 @@ class SmallWebRTCConnection(BaseObject):
self._pc = RTCPeerConnection(rtc_config)
self._pc_id = self.name
self._setup_listeners()
self._tracks = set()
self._data_channel = None
self._renegotiation_in_progress = False
self._last_received_time = None
@@ -155,10 +96,7 @@ class SmallWebRTCConnection(BaseObject):
self._last_received_time = time.time()
else:
json_message = json.loads(message)
if json_message["type"] == SIGNALLING_TYPE and json_message.get("message"):
self._handle_signalling_message(json_message["message"])
else:
await self._call_event_handler("app-message", json_message)
await self._call_event_handler("app-message", json_message)
except Exception as e:
logger.exception(f"Error parsing JSON message {message}, {e}")
@@ -183,11 +121,13 @@ class SmallWebRTCConnection(BaseObject):
@self._pc.on("track")
async def on_track(track):
logger.debug(f"Track {track.kind} received")
self._tracks.add(track)
await self._call_event_handler("track-started", track)
@track.on("ended")
async def on_ended():
logger.debug(f"Track {track.kind} ended")
self._tracks.discard(track)
await self._call_event_handler("track-ended", track)
async def _create_answer(self, sdp: str, type: str):
@@ -208,6 +148,17 @@ class SmallWebRTCConnection(BaseObject):
async def initialize(self, sdp: str, type: str):
await self._create_answer(sdp, type)
async def discard_old_frames(self, remote_track: RemoteStreamTrack):
if not hasattr(remote_track, "_queue") or not isinstance(
remote_track._queue, asyncio.Queue
):
print("Warning: _queue does not exist or has changed in aiortc.")
return
logger.debug("Discarding old frames")
while not remote_track._queue.empty():
remote_track._queue.get_nowait() # Remove the oldest frame
remote_track._queue.task_done()
async def connect(self):
self._connect_invoked = True
# If we already connected, trigger again the connected event
@@ -215,7 +166,9 @@ class SmallWebRTCConnection(BaseObject):
await self._call_event_handler("connected")
# We are renegotiating here, because likely we have loose the first video frames
# and aiortc does not handle that pretty well.
await self.video_input_track().discard_old_frames()
remove_video_track = self.video_input_track()
if isinstance(remove_video_track, RemoteStreamTrack):
await self.discard_old_frames(remove_video_track)
self.ask_to_renegotiate()
async def renegotiate(self, sdp: str, type: str, restart_pc: bool = False):
@@ -275,7 +228,6 @@ class SmallWebRTCConnection(BaseObject):
if self._pc:
await self._pc.close()
self._message_queue.clear()
self._track_map = {}
def get_answer(self):
if not self._answer:
@@ -315,38 +267,29 @@ class SmallWebRTCConnection(BaseObject):
return (time.time() - self._last_received_time) < 3
def audio_input_track(self):
if self._track_map.get(AUDIO_TRANSCEIVER_INDEX):
return self._track_map[AUDIO_TRANSCEIVER_INDEX]
# Transceivers always appear in creation-order for both peers
# For now we are only considering that we are going to have 02 transceivers,
# one for audio and one for video
transceivers = self._pc.getTransceivers()
if len(transceivers) == 0 or not transceivers[AUDIO_TRANSCEIVER_INDEX].receiver:
if len(transceivers) == 0 or not transceivers[0].receiver:
logger.warning("No audio transceiver is available")
return None
track = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver.track
audio_track = SmallWebRTCTrack(track) if track else None
self._track_map[AUDIO_TRANSCEIVER_INDEX] = audio_track
return audio_track
return transceivers[0].receiver.track
def video_input_track(self):
if self._track_map.get(VIDEO_TRANSCEIVER_INDEX):
return self._track_map[VIDEO_TRANSCEIVER_INDEX]
# Transceivers always appear in creation-order for both peers
# For now we are only considering that we are going to have 02 transceivers,
# one for audio and one for video
transceivers = self._pc.getTransceivers()
if len(transceivers) <= 1 or not transceivers[VIDEO_TRANSCEIVER_INDEX].receiver:
if len(transceivers) <= 1 or not transceivers[1].receiver:
logger.warning("No video transceiver is available")
return None
track = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver.track
video_track = SmallWebRTCTrack(track) if track else None
self._track_map[VIDEO_TRANSCEIVER_INDEX] = video_track
return video_track
return transceivers[1].receiver.track
def tracks(self):
return self._tracks
def send_app_message(self, message: Any):
json_message = json.dumps(message)
@@ -362,17 +305,5 @@ class SmallWebRTCConnection(BaseObject):
self._renegotiation_in_progress = True
self.send_app_message(
{"type": SIGNALLING_TYPE, "message": RenegotiateMessage().model_dump()}
{"type": SIGNALLING_TYPE, "message": SignallingMessage.RENEGOTIATE.value}
)
def _handle_signalling_message(self, message):
logger.debug(f"Signalling message received: {message}")
inbound_adapter = TypeAdapter(SignallingMessage.Inbound)
signalling_message = inbound_adapter.validate_python(message)
match signalling_message:
case TrackStatusMessage():
track = (
self._track_getters.get(signalling_message.receiver_index) or (lambda: None)
)()
if track:
track.set_enabled(signalling_message.enabled)