169 lines
5.2 KiB
Python
169 lines
5.2 KiB
Python
#
|
||
# Copyright (c) 2024–2025, Daily
|
||
#
|
||
# SPDX-License-Identifier: BSD 2-Clause License
|
||
#
|
||
|
||
import argparse
|
||
import json
|
||
import sys
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from http import HTTPMethod
|
||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||
|
||
import boto3
|
||
import uvicorn
|
||
from fastapi import BackgroundTasks, FastAPI, Request, Response
|
||
from fastapi.responses import RedirectResponse
|
||
from loguru import logger
|
||
from pipecat.transports.smallwebrtc.request_handler import (
|
||
IceCandidate,
|
||
SmallWebRTCPatchRequest,
|
||
SmallWebRTCRequest,
|
||
)
|
||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||
|
||
app = FastAPI()

# Serve the prebuilt SmallWebRTC frontend under /client
# (requests to "/" are redirected there by root_redirect).
app.mount("/client", SmallWebRTCPrebuiltUI)

# In-memory store of active sessions: session_id -> session info
# (the session info is the JSON body that was posted to /start).
active_sessions: Dict[str, Dict[str, Any]] = {}

# Initialize Bedrock client
bedrock = boto3.client("bedrock-agent-runtime")
|
||
|
||
|
||
@app.get("/", include_in_schema=False)
async def root_redirect():
    """Send visitors of the site root to the mounted frontend at ``/client/``."""
    frontend_url = "/client/"
    return RedirectResponse(url=frontend_url)
|
||
|
||
|
||
@app.post("/api/offer")
async def offer(request: Request):
    """Forward a WebRTC offer to the Bedrock agent and return its reply.

    The JSON request body is wrapped as ``{"input": <body>}``, serialized and
    sent to the agent as its input text. The agent's streamed completion is
    collected and returned to the caller.

    Args:
        request: Incoming HTTP request whose body is the JSON offer payload.

    Returns:
        A dict ``{"result": <text>}`` holding the UTF-8 decoded agent reply.
    """
    # Parse JSON body (expecting something like {"input": "..."} )
    data = await request.json()
    logger.debug(f"Received offer: {data}")

    # NOTE(review): boto3 calls are synchronous, so this blocks the event loop
    # for the duration of the agent invocation.
    response = bedrock.invoke_agent(
        # TODO: create a custom random id, maybe based on the pc_id
        agentId="network_test-11111111",
        agentAliasId="Network Test Agent",
        sessionId="user-123456-conversation-11111",
        inputText=json.dumps({"input": data}),
    )

    # invoke_agent does not return a "body" stream: the reply is delivered as
    # an event stream under "completion", with the text carried by "chunk"
    # events. Other event kinds (e.g. traces) carry no reply text and are
    # skipped.
    reply_parts = [
        event["chunk"].get("bytes", b"")
        for event in response["completion"]
        if "chunk" in event
    ]
    result = b"".join(reply_parts).decode("utf-8")

    return {"result": result}
|
||
|
||
|
||
@app.patch("/api/offer")
async def ice_candidate(request: SmallWebRTCPatchRequest):
    """Acknowledge a new ICE candidate request for an existing WebRTC offer.

    The request is only logged for now; it is not forwarded anywhere yet.
    """
    logger.debug(f"Received patch request: {request}")
    # TODO need to send this to agentcore
    acknowledgement = {"status": "success"}
    return acknowledgement
|
||
|
||
|
||
@app.post("/start")
async def rtvi_start(request: Request):
    """Mimic Pipecat Cloud's /start endpoint.

    Creates a new session id, remembers the request body for it in
    ``active_sessions`` and returns the id to the client. A body that is
    missing, is not valid JSON, or is not a JSON object is treated as an
    empty object.

    Args:
        request: Incoming HTTP request; its JSON body may set
            ``enableDefaultIceServers`` to receive a default ICE configuration.

    Returns:
        A dict with ``sessionId`` and, when default ICE servers were
        requested, ``iceConfig``.
    """

    class IceServer(TypedDict, total=False):
        urls: Union[str, List[str]]

    class IceConfig(TypedDict):
        iceServers: List[IceServer]

    class StartBotResult(TypedDict, total=False):
        sessionId: str
        iceConfig: Optional[IceConfig]

    # Parse the request body
    try:
        request_data = await request.json()
        logger.debug(f"Received request: {request_data}")
    except Exception as e:
        logger.error(f"Failed to parse request body: {e}")
        request_data = {}

    # Valid JSON is not necessarily an object (it may be a list, string,
    # number or null); the lookups below need a dict, so fall back to an
    # empty one instead of failing with an AttributeError.
    if not isinstance(request_data, dict):
        logger.error(
            f"Expected a JSON object in request body, got {type(request_data).__name__}"
        )
        request_data = {}

    # Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud
    session_id = str(uuid.uuid4())
    active_sessions[session_id] = request_data

    result: StartBotResult = {"sessionId": session_id}
    if request_data.get("enableDefaultIceServers"):
        result["iceConfig"] = IceConfig(
            iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])]
        )

    return result
|
||
|
||
|
||
@app.api_route(
    "/sessions/{session_id}/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)
async def proxy_request(
    session_id: str, path: str, request: Request, background_tasks: BackgroundTasks
):
    """Mimic Pipecat Cloud's proxy.

    Requests for an unknown session are rejected with 404. Requests whose
    path ends in ``api/offer`` are validated and dispatched to the local
    ``offer`` (POST) or ``ice_candidate`` (PATCH) handler; any other request
    is logged and answered with an empty 200.

    Args:
        session_id: Session identifier previously returned by ``/start``.
        path: Remainder of the URL after the session id.
        request: The incoming HTTP request.
        background_tasks: Accepted as part of the route signature; not used
            by this handler.

    Returns:
        The dispatched handler's result, or a ``Response`` with status 404
        (unknown session), 400 (malformed WebRTC body) or 200.
    """
    if active_sessions.get(session_id) is None:
        return Response(content="Invalid or not-yet-ready session_id", status_code=404)

    if path.endswith("api/offer"):
        is_offer = request.method == HTTPMethod.POST.value
        is_patch = request.method == HTTPMethod.PATCH.value
        patch_request = None

        # Only parsing/validation lives inside the try, so that a failure in
        # the downstream handler is not misreported as a malformed request.
        try:
            request_data = await request.json()
            if is_offer:
                # Built purely to validate that the body has the shape of a
                # WebRTC offer (raises if "sdp" or "type" is missing).
                SmallWebRTCRequest(
                    sdp=request_data["sdp"],
                    type=request_data["type"],
                    pc_id=request_data.get("pc_id"),
                    restart_pc=request_data.get("restart_pc"),
                    request_data=request_data,
                )
            elif is_patch:
                patch_request = SmallWebRTCPatchRequest(
                    pc_id=request_data["pc_id"],
                    candidates=[IceCandidate(**c) for c in request_data.get("candidates", [])],
                )
        except Exception as e:
            logger.error(f"Failed to parse WebRTC request: {e}")
            return Response(content="Invalid WebRTC request", status_code=400)

        if is_offer:
            # offer() takes the HTTP request itself and re-reads its JSON
            # body (Starlette caches the parsed body, so this is safe).
            return await offer(request)
        if is_patch:
            return await ice_candidate(patch_request)

    logger.info(f"Received request for path: {path}")
    return Response(status_code=200)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the application with no startup or shutdown work.

    Control is handed over immediately and nothing is cleaned up afterwards.
    """
    yield None  # Run app
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line options for the local demo server.
    parser = argparse.ArgumentParser(description="WebRTC demo")
    parser.add_argument(
        "--host",
        default="localhost",
        help="Host for HTTP server (default: localhost)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port for HTTP server (default: 7860)",
    )
    parser.add_argument("--verbose", "-v", action="count")
    args = parser.parse_args()

    # Replace loguru's default sink with a stderr sink; any -v flag lowers
    # the threshold from DEBUG to TRACE.
    logger.remove(0)
    log_level = "TRACE" if args.verbose else "DEBUG"
    logger.add(sys.stderr, level=log_level)

    uvicorn.run(app, host=args.host, port=args.port)
|