Merge pull request #1539 from pipecat-ai/small_wbertc_mute_state
SmallWebRTC mute state
This commit is contained in:
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added support in `SmallWebRTCTransport` to detect when remote tracks are
|
||||
muted.
|
||||
|
||||
- Added support for image capture from a video stream to the
|
||||
`SmallWebRTCTransport`.
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@
|
||||
<div class="bot-container">
|
||||
<div id="bot-video-container">
|
||||
<video id="bot-video" autoplay="true" playsinline="true"></video>
|
||||
<button id="mute-btn">📷</button>
|
||||
</div>
|
||||
<audio id="bot-audio" autoplay></audio>
|
||||
</div>
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"@pipecat-ai/client-js": "^0.3.2",
|
||||
"@pipecat-ai/small-webrtc-transport": "^0.0.1"
|
||||
"@pipecat-ai/small-webrtc-transport": "^0.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.13.1",
|
||||
@@ -32,9 +32,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@daily-co/daily-js": {
|
||||
"version": "0.73.0",
|
||||
"resolved": "https://registry.npmjs.org/@daily-co/daily-js/-/daily-js-0.73.0.tgz",
|
||||
"integrity": "sha512-Wz8c60hgmkx8fcEeDAi4L4J0rbafiihWKyXFyhYoFYPsw2OdChHpA4RYwIB+1enRws5IK+/HdmzFDYLQsB4A6w==",
|
||||
"version": "0.77.0",
|
||||
"resolved": "https://registry.npmjs.org/@daily-co/daily-js/-/daily-js-0.77.0.tgz",
|
||||
"integrity": "sha512-icNXKieKAkRR/C5dcPjrCkL1jQGFp5C5WtLHy5uHAdTztm+mo9wlPJuehbWaGOM3TV24mgWHZ/+8jOys1G0I4w==",
|
||||
"license": "BSD-2-Clause",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.12.5",
|
||||
@@ -78,12 +78,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@pipecat-ai/small-webrtc-transport": {
|
||||
"version": "0.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@pipecat-ai/small-webrtc-transport/-/small-webrtc-transport-0.0.1.tgz",
|
||||
"integrity": "sha512-WAOI7lT0V7cYOn0+qwUAryGxcOGe+wPVPEPzkR3qsM5GWIZ73spykZnuOndQGycq4UkcXVawCzERfNhpi+Uv7A==",
|
||||
"version": "0.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@pipecat-ai/small-webrtc-transport/-/small-webrtc-transport-0.0.2.tgz",
|
||||
"integrity": "sha512-9QQBjfAY0yh+ehDt6jX+bX7Ar5GFl+iI6QFS+JPRXeDYCj70bqmUgCYkScbgWzb5uRWZ8ORM+ueVkaLibe+Y4Q==",
|
||||
"license": "BSD-2-Clause",
|
||||
"dependencies": {
|
||||
"@daily-co/daily-js": "^0.73.0",
|
||||
"@daily-co/daily-js": "^0.77.0",
|
||||
"dequal": "^2.0.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
||||
@@ -19,6 +19,6 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@pipecat-ai/client-js": "^0.3.2",
|
||||
"@pipecat-ai/small-webrtc-transport": "^0.0.1"
|
||||
"@pipecat-ai/small-webrtc-transport": "^0.0.2"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import {
|
||||
SmallWebRTCTransport
|
||||
} from "@pipecat-ai/small-webrtc-transport";
|
||||
import {Participant, RTVIClient, RTVIClientOptions} from "@pipecat-ai/client-js";
|
||||
import {Participant, RTVIClient, RTVIClientOptions, Transport} from "@pipecat-ai/client-js";
|
||||
|
||||
class WebRTCApp {
|
||||
|
||||
private declare connectBtn: HTMLButtonElement;
|
||||
private declare disconnectBtn: HTMLButtonElement;
|
||||
private declare muteBtn: HTMLButtonElement;
|
||||
|
||||
private declare audioInput: HTMLSelectElement;
|
||||
private declare videoInput: HTMLSelectElement;
|
||||
@@ -32,12 +33,10 @@ class WebRTCApp {
|
||||
private initializeRTVIClient(): void {
|
||||
const transport = new SmallWebRTCTransport();
|
||||
const RTVIConfig: RTVIClientOptions = {
|
||||
// need to understand why it is complaining
|
||||
// @ts-ignore
|
||||
transport,
|
||||
params: {
|
||||
baseUrl: "/api/offer"
|
||||
},
|
||||
transport: transport as Transport,
|
||||
enableMic: true,
|
||||
enableCam: true,
|
||||
callbacks: {
|
||||
@@ -92,6 +91,7 @@ class WebRTCApp {
|
||||
private setupDOMElements(): void {
|
||||
this.connectBtn = document.getElementById('connect-btn') as HTMLButtonElement;
|
||||
this.disconnectBtn = document.getElementById('disconnect-btn') as HTMLButtonElement;
|
||||
this.muteBtn = document.getElementById('mute-btn') as HTMLButtonElement;
|
||||
|
||||
this.audioInput = document.getElementById('audio-input') as HTMLSelectElement;
|
||||
this.videoInput = document.getElementById('video-input') as HTMLSelectElement;
|
||||
@@ -118,6 +118,12 @@ class WebRTCApp {
|
||||
let videoDevice = e.target?.value
|
||||
this.rtviClient.updateCam(videoDevice)
|
||||
})
|
||||
this.muteBtn.addEventListener('click', () => {
|
||||
let isCamEnabled = this.rtviClient.isCamEnabled
|
||||
this.rtviClient.enableCam(!isCamEnabled)
|
||||
this.muteBtn.textContent = isCamEnabled ? '📵' : '📷';
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
private log(message: string): void {
|
||||
|
||||
@@ -89,6 +89,7 @@ button:disabled {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#bot-video-container video {
|
||||
@@ -97,6 +98,20 @@ button:disabled {
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
#mute-btn {
|
||||
position: absolute;
|
||||
bottom: 10px;
|
||||
right: 10px;
|
||||
background-color: rgba(0, 0, 0, 0.6);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 20px;
|
||||
padding: 8px 12px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.debug-panel {
|
||||
background-color: #fff;
|
||||
border-radius: 8px;
|
||||
|
||||
@@ -7,15 +7,22 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from av.frame import Frame
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from pipecat.utils.base_object import BaseObject
|
||||
|
||||
try:
|
||||
from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc import (
|
||||
MediaStreamTrack,
|
||||
RTCConfiguration,
|
||||
RTCIceServer,
|
||||
RTCPeerConnection,
|
||||
RTCSessionDescription,
|
||||
)
|
||||
from aiortc.rtcrtpreceiver import RemoteStreamTrack
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -23,10 +30,57 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
SIGNALLING_TYPE = "signalling"
|
||||
AUDIO_TRANSCEIVER_INDEX = 0
|
||||
VIDEO_TRANSCEIVER_INDEX = 1
|
||||
|
||||
|
||||
class SignallingMessage(Enum):
|
||||
RENEGOTIATE = "renegotiate"
|
||||
class TrackStatusMessage(BaseModel):
    # Inbound signalling payload: the remote peer reports whether one of its
    # tracks is currently enabled.

    # Discriminator literal identifying this message kind.
    type: Literal["trackStatus"]
    # Index used to look up the affected input track (see
    # AUDIO_TRANSCEIVER_INDEX / VIDEO_TRANSCEIVER_INDEX).
    receiver_index: int
    # Value forwarded to the track's enabled flag.
    enabled: bool


class RenegotiateMessage(BaseModel):
    # Outbound signalling payload asking the remote peer to renegotiate.

    type: Literal["renegotiate"] = "renegotiate"


class SignallingMessage:
    # Namespace grouping the signalling payload types by direction.

    # Messages accepted from the remote peer. Kept as a Union so that new
    # inbound message types can be added in the future.
    Inbound = Union[TrackStatusMessage]
    # Messages sent to the remote peer.
    Outbound = Union[RenegotiateMessage]
    # Backward-compatible alias: the attribute was originally spelled in
    # lowercase, inconsistently with `Inbound`.
    outbound = Outbound
|
||||
|
||||
|
||||
class SmallWebRTCTrack:
    """Wrapper around a media track that adds an enabled/disabled flag.

    While the wrapper is disabled, `recv()` returns ``None`` without reading
    from the wrapped track. Attributes that are not defined on the wrapper are
    forwarded to the wrapped track.
    """

    def __init__(self, track: MediaStreamTrack):
        """Wrap `track`; the wrapper starts in the enabled state."""
        self._track = track
        self._enabled = True

    def set_enabled(self, enabled: bool) -> None:
        """Set whether `recv()` should deliver frames from the wrapped track."""
        self._enabled = enabled

    def is_enabled(self) -> bool:
        """Return the current enabled flag."""
        return self._enabled

    async def discard_old_frames(self):
        """Drop every frame currently buffered on a remote track.

        Does nothing unless the wrapped track is a `RemoteStreamTrack`. Relies
        on the private `_queue` attribute of aiortc's remote track, so the
        attribute is checked first and a warning is logged if it is missing or
        is no longer an `asyncio.Queue`.
        """
        remote_track = self._track
        if isinstance(remote_track, RemoteStreamTrack):
            if not hasattr(remote_track, "_queue") or not isinstance(
                remote_track._queue, asyncio.Queue
            ):
                # Use the module logger (like the rest of this file) instead of print.
                logger.warning("_queue does not exist or has changed in aiortc.")
                return
            logger.debug("Discarding old frames")
            while not remote_track._queue.empty():
                remote_track._queue.get_nowait()  # Remove the oldest frame
                remote_track._queue.task_done()

    async def recv(self) -> Optional[Frame]:
        """Return the next frame from the wrapped track, or ``None`` when disabled."""
        if not self._enabled:
            return None
        return await self._track.recv()

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If the missing name is
        # `_track` itself (instance created without __init__, e.g. by copy or
        # unpickling), forwarding would recurse forever, so fail cleanly instead.
        if name == "_track":
            raise AttributeError(name)
        # Forward other attribute/method calls to the underlying track
        return getattr(self._track, name)
|
||||
|
||||
|
||||
class SmallWebRTCConnection(BaseObject):
|
||||
@@ -37,6 +91,12 @@ class SmallWebRTCConnection(BaseObject):
|
||||
else:
|
||||
self.ice_servers = []
|
||||
self._connect_invoked = False
|
||||
self._track_map = {}
|
||||
self._track_getters = {
|
||||
AUDIO_TRANSCEIVER_INDEX: self.audio_input_track,
|
||||
VIDEO_TRANSCEIVER_INDEX: self.video_input_track,
|
||||
}
|
||||
|
||||
self._initialize()
|
||||
|
||||
# Register supported handlers. The user will only be able to register
|
||||
@@ -68,7 +128,6 @@ class SmallWebRTCConnection(BaseObject):
|
||||
self._pc = RTCPeerConnection(rtc_config)
|
||||
self._pc_id = self.name
|
||||
self._setup_listeners()
|
||||
self._tracks = set()
|
||||
self._data_channel = None
|
||||
self._renegotiation_in_progress = False
|
||||
self._last_received_time = None
|
||||
@@ -96,7 +155,10 @@ class SmallWebRTCConnection(BaseObject):
|
||||
self._last_received_time = time.time()
|
||||
else:
|
||||
json_message = json.loads(message)
|
||||
await self._call_event_handler("app-message", json_message)
|
||||
if json_message["type"] == SIGNALLING_TYPE and json_message.get("message"):
|
||||
self._handle_signalling_message(json_message["message"])
|
||||
else:
|
||||
await self._call_event_handler("app-message", json_message)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error parsing JSON message {message}, {e}")
|
||||
|
||||
@@ -121,13 +183,11 @@ class SmallWebRTCConnection(BaseObject):
|
||||
@self._pc.on("track")
|
||||
async def on_track(track):
|
||||
logger.debug(f"Track {track.kind} received")
|
||||
self._tracks.add(track)
|
||||
await self._call_event_handler("track-started", track)
|
||||
|
||||
@track.on("ended")
|
||||
async def on_ended():
|
||||
logger.debug(f"Track {track.kind} ended")
|
||||
self._tracks.discard(track)
|
||||
await self._call_event_handler("track-ended", track)
|
||||
|
||||
async def _create_answer(self, sdp: str, type: str):
|
||||
@@ -148,17 +208,6 @@ class SmallWebRTCConnection(BaseObject):
|
||||
async def initialize(self, sdp: str, type: str):
|
||||
await self._create_answer(sdp, type)
|
||||
|
||||
async def discard_old_frames(self, remote_track: RemoteStreamTrack):
|
||||
if not hasattr(remote_track, "_queue") or not isinstance(
|
||||
remote_track._queue, asyncio.Queue
|
||||
):
|
||||
print("Warning: _queue does not exist or has changed in aiortc.")
|
||||
return
|
||||
logger.debug("Discarding old frames")
|
||||
while not remote_track._queue.empty():
|
||||
remote_track._queue.get_nowait() # Remove the oldest frame
|
||||
remote_track._queue.task_done()
|
||||
|
||||
async def connect(self):
|
||||
self._connect_invoked = True
|
||||
# If we already connected, trigger again the connected event
|
||||
@@ -166,9 +215,7 @@ class SmallWebRTCConnection(BaseObject):
|
||||
await self._call_event_handler("connected")
|
||||
# We are renegotiating here, because we have likely lost the first video frames
|
||||
# and aiortc does not handle that pretty well.
|
||||
remove_video_track = self.video_input_track()
|
||||
if isinstance(remove_video_track, RemoteStreamTrack):
|
||||
await self.discard_old_frames(remove_video_track)
|
||||
await self.video_input_track().discard_old_frames()
|
||||
self.ask_to_renegotiate()
|
||||
|
||||
async def renegotiate(self, sdp: str, type: str, restart_pc: bool = False):
|
||||
@@ -228,6 +275,7 @@ class SmallWebRTCConnection(BaseObject):
|
||||
if self._pc:
|
||||
await self._pc.close()
|
||||
self._message_queue.clear()
|
||||
self._track_map = {}
|
||||
|
||||
def get_answer(self):
|
||||
if not self._answer:
|
||||
@@ -267,29 +315,38 @@ class SmallWebRTCConnection(BaseObject):
|
||||
return (time.time() - self._last_received_time) < 3
|
||||
|
||||
def audio_input_track(self):
|
||||
if self._track_map.get(AUDIO_TRANSCEIVER_INDEX):
|
||||
return self._track_map[AUDIO_TRANSCEIVER_INDEX]
|
||||
|
||||
# Transceivers always appear in creation-order for both peers
|
||||
# For now we are only considering that we are going to have two transceivers,
|
||||
# one for audio and one for video
|
||||
transceivers = self._pc.getTransceivers()
|
||||
if len(transceivers) == 0 or not transceivers[0].receiver:
|
||||
if len(transceivers) == 0 or not transceivers[AUDIO_TRANSCEIVER_INDEX].receiver:
|
||||
logger.warning("No audio transceiver is available")
|
||||
return None
|
||||
|
||||
return transceivers[0].receiver.track
|
||||
track = transceivers[AUDIO_TRANSCEIVER_INDEX].receiver.track
|
||||
audio_track = SmallWebRTCTrack(track) if track else None
|
||||
self._track_map[AUDIO_TRANSCEIVER_INDEX] = audio_track
|
||||
return audio_track
|
||||
|
||||
def video_input_track(self):
|
||||
if self._track_map.get(VIDEO_TRANSCEIVER_INDEX):
|
||||
return self._track_map[VIDEO_TRANSCEIVER_INDEX]
|
||||
|
||||
# Transceivers always appear in creation-order for both peers
|
||||
# For now we are only considering that we are going to have two transceivers,
|
||||
# one for audio and one for video
|
||||
transceivers = self._pc.getTransceivers()
|
||||
if len(transceivers) <= 1 or not transceivers[1].receiver:
|
||||
if len(transceivers) <= 1 or not transceivers[VIDEO_TRANSCEIVER_INDEX].receiver:
|
||||
logger.warning("No video transceiver is available")
|
||||
return None
|
||||
|
||||
return transceivers[1].receiver.track
|
||||
|
||||
def tracks(self):
|
||||
return self._tracks
|
||||
track = transceivers[VIDEO_TRANSCEIVER_INDEX].receiver.track
|
||||
video_track = SmallWebRTCTrack(track) if track else None
|
||||
self._track_map[VIDEO_TRANSCEIVER_INDEX] = video_track
|
||||
return video_track
|
||||
|
||||
def send_app_message(self, message: Any):
|
||||
json_message = json.dumps(message)
|
||||
@@ -305,5 +362,17 @@ class SmallWebRTCConnection(BaseObject):
|
||||
|
||||
self._renegotiation_in_progress = True
|
||||
self.send_app_message(
|
||||
{"type": SIGNALLING_TYPE, "message": SignallingMessage.RENEGOTIATE.value}
|
||||
{"type": SIGNALLING_TYPE, "message": RenegotiateMessage().model_dump()}
|
||||
)
|
||||
|
||||
def _handle_signalling_message(self, message):
    """Validate an inbound signalling payload and apply it.

    A track-status message sets the enabled flag of the input track that is
    registered under the message's receiver index. If no getter is registered
    for that index, or the getter yields no track, nothing happens.
    """
    logger.debug(f"Signalling message received: {message}")
    parsed = TypeAdapter(SignallingMessage.Inbound).validate_python(message)
    if not isinstance(parsed, TrackStatusMessage):
        return

    # Resolve the track lazily through the getter registered for this index.
    getter = self._track_getters.get(parsed.receiver_index)
    track = getter() if getter else None
    if track:
        track.set_enabled(parsed.enabled)
|
||||
|
||||
Reference in New Issue
Block a user