Files
pipecat/examples/aws-agentcore/server/server.py
Filipi Fuchter ebbb097e5f Creating server
2025-11-28 06:32:29 -03:00

153 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#
# Copyright (c) 2024-2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import argparse
import sys
import uuid
from contextlib import asynccontextmanager
from http import HTTPMethod
from typing import Any, Dict, List, Optional, TypedDict, Union
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request, Response
from fastapi.responses import RedirectResponse
from loguru import logger
from pipecat.transports.smallwebrtc.request_handler import (
IceCandidate,
SmallWebRTCPatchRequest,
SmallWebRTCRequest,
)
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
app = FastAPI()
# Serve the prebuilt SmallWebRTC UI under /client (the "/" route redirects there)
app.mount("/client", SmallWebRTCPrebuiltUI)
# In-memory store of active sessions: session_id -> request body received by /start
active_sessions: Dict[str, Dict[str, Any]] = {}
@app.get("/", include_in_schema=False)
async def root_redirect():
    """Redirect requests for the bare root URL to the client UI path."""
    client_url = "/client/"
    return RedirectResponse(url=client_url)
@app.post("/api/offer")
async def offer(request: SmallWebRTCRequest):
    """Handle a WebRTC offer request (placeholder implementation).

    The offer is only logged for now; forwarding it to AgentCore is still
    pending (see the TODO below).

    Args:
        request: The parsed WebRTC offer payload.

    Returns:
        Always ``None`` at the moment, since no answer is produced yet.
    """
    # Log through loguru, like the PATCH handler does, rather than print(),
    # so the message goes to the configured sink and respects the log level.
    logger.debug(f"Received offer:{request}")
    # TODO need to send this to agentcore
    answer = None
    return answer
@app.patch("/api/offer")
async def ice_candidate(request: SmallWebRTCPatchRequest):
    """Handle a PATCH request carrying new ICE candidates.

    The request is only logged for now and a fixed success status is returned.
    """
    logger.debug(f"Received patch request: {request}")
    # TODO need to send this to agentcore
    acknowledgement = {"status": "success"}
    return acknowledgement
@app.post("/start")
async def rtvi_start(request: Request):
    """Mimic Pipecat Cloud's /start endpoint.

    Generates a new session id, stores the request body under that id in
    ``active_sessions`` and returns the id, optionally together with a
    default ICE server configuration.

    Args:
        request: Incoming HTTP request. Its JSON body, when it is a JSON
            object, is stored as the session info. The only key read here
            is ``enableDefaultIceServers``.

    Returns:
        A dict with ``sessionId`` and, when ``enableDefaultIceServers`` is
        truthy in the body, an ``iceConfig`` entry.
    """

    class IceServer(TypedDict, total=False):
        urls: Union[str, List[str]]

    class IceConfig(TypedDict):
        iceServers: List[IceServer]

    class StartBotResult(TypedDict, total=False):
        sessionId: str
        iceConfig: Optional[IceConfig]

    # Parse the request body; a missing or malformed body is treated as empty
    # so that a session can still be created.
    try:
        request_data = await request.json()
        logger.debug(f"Received request: {request_data}")
    except Exception as e:
        logger.error(f"Failed to parse request body: {e}")
        request_data = {}

    # Valid JSON is not necessarily an object (it may be null, a list, a
    # string, ...). Anything else would break the .get() lookup below and
    # the dict-typed session store, so fall back to an empty dict.
    if not isinstance(request_data, dict):
        logger.error(
            f"Request body is not a JSON object (got {type(request_data).__name__}); ignoring it"
        )
        request_data = {}

    # Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud
    session_id = str(uuid.uuid4())
    active_sessions[session_id] = request_data

    result: StartBotResult = {"sessionId": session_id}
    if request_data.get("enableDefaultIceServers"):
        result["iceConfig"] = IceConfig(
            iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])]
        )
    return result
@app.api_route(
    "/sessions/{session_id}/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)
async def proxy_request(
    session_id: str, path: str, request: Request, background_tasks: BackgroundTasks
):
    """Mimic Pipecat Cloud's proxy.

    Requests whose path ends with ``api/offer`` are converted to the
    matching SmallWebRTC request object and dispatched to the local
    ``offer`` (POST) or ``ice_candidate`` (PATCH) handler. Every other
    request for a known session is logged and acknowledged with 200.

    Args:
        session_id: Id previously stored in ``active_sessions``.
        path: Remainder of the URL after the session id.
        request: The incoming HTTP request.
        background_tasks: Kept in the signature for interface
            compatibility; not used by this function.

    Returns:
        404 for an unknown session, 400 when the WebRTC payload cannot be
        parsed, the handler's result for offer/ICE requests, otherwise an
        empty 200 response.
    """
    active_session = active_sessions.get(session_id)
    if active_session is None:
        return Response(content="Invalid or not-yet-ready session_id", status_code=404)

    if path.endswith("api/offer"):
        webrtc_request = None
        patch_request = None
        # Only parsing/validation lives inside the try block, so that errors
        # raised by the handlers themselves are not misreported as a
        # malformed client request.
        try:
            request_data = await request.json()
            if request.method == HTTPMethod.POST.value:
                webrtc_request = SmallWebRTCRequest(
                    sdp=request_data["sdp"],
                    type=request_data["type"],
                    pc_id=request_data.get("pc_id"),
                    restart_pc=request_data.get("restart_pc"),
                    request_data=request_data,
                )
            elif request.method == HTTPMethod.PATCH.value:
                patch_request = SmallWebRTCPatchRequest(
                    pc_id=request_data["pc_id"],
                    candidates=[IceCandidate(**c) for c in request_data.get("candidates", [])],
                )
        except Exception as e:
            logger.error(f"Failed to parse WebRTC request: {e}")
            return Response(content="Invalid WebRTC request", status_code=400)

        if webrtc_request is not None:
            # offer() takes only the request. Passing background_tasks as a
            # second argument raised TypeError, which the former broad
            # except turned into a 400 for every proxied offer.
            return await offer(webrtc_request)
        if patch_request is not None:
            return await ice_candidate(patch_request)

    logger.info(f"Received request for path: {path}")
    return Response(status_code=200)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the app; performs no startup or shutdown work.

    NOTE(review): ``app`` is created with a bare ``FastAPI()`` in this file,
    so this hook does not appear to be registered anywhere - confirm.
    """
    # Hand control to the running application.
    yield
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, start uvicorn.
    cli = argparse.ArgumentParser(description="WebRTC demo")
    cli.add_argument(
        "--host", default="localhost", help="Host for HTTP server (default: localhost)"
    )
    cli.add_argument(
        "--port", type=int, default=7860, help="Port for HTTP server (default: 7860)"
    )
    cli.add_argument("--verbose", "-v", action="count")
    options = cli.parse_args()

    # Replace handler 0 with a stderr sink: TRACE when -v is given, DEBUG otherwise.
    logger.remove(0)
    log_level = "TRACE" if options.verbose else "DEBUG"
    logger.add(sys.stderr, level=log_level)

    uvicorn.run(app, host=options.host, port=options.port)