PR comments

Also satisfy some Pyright complaints and update default model
This commit is contained in:
Mike Depinet
2025-12-12 15:03:31 -08:00
parent ccdf83800b
commit 2e4fa3f8db
3 changed files with 83 additions and 72 deletions

View File

@@ -36,17 +36,14 @@ transport_params = {
"daily": lambda: DailyParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_enabled=False,
),
"twilio": lambda: FastAPIWebsocketParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_enabled=False,
),
"webrtc": lambda: TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
vad_enabled=False,
),
}

View File

@@ -1 +0,0 @@
from .llm import UltravoxRealtimeLLMService

View File

@@ -26,8 +26,6 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.utils import create_stream_resampler
from pipecat.frames.frames import (
AggregationType,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
Frame,
@@ -44,6 +42,7 @@ from pipecat.frames.frames import (
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
UserAudioRawFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import (
@@ -115,7 +114,7 @@ class OneShotInputParams(BaseModel):
api_key: str
system_prompt: Optional[str] = None
temperature: float = Field(default=0.0, ge=0.0, le=1.0)
model: str = "fixie-ai/ultravox"
model: Optional[str] = None
voice: Optional[uuid.UUID] = None
metadata: Dict[str, str] = Field(default_factory=dict)
max_duration: datetime.timedelta = Field(
@@ -177,6 +176,7 @@ class UltravoxRealtimeLLMService(LLMService):
self._receive_task: Optional[asyncio.Task] = None
self._disconnecting = False
self._bot_responding: Literal[None, "text", "voice"] = None
self._last_user_id: Optional[str] = None
self._sample_rate = 48000
self._resampler = create_stream_resampler()
@@ -193,61 +193,72 @@ class UltravoxRealtimeLLMService(LLMService):
"""
await super().start(frame)
match self._params:
case JoinUrlInputParams():
join_url = self._params.join_url
case AgentInputParams():
request_body = {
"templateContext": self._params.template_context,
"metadata": self._params.metadata,
"maxDuration": f"{self._params.max_duration.total_seconds():3f}s",
"medium": {
"serverWebSocket": {
"inputSampleRate": self._sample_rate,
}
},
} | self._params.extra
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://api.ultravox.ai/api/agents/{self._params.agent_id}/calls",
headers={"X-Api-Key": self._params.api_key},
json=request_body,
) as response:
if response.status != 201:
error_text = await response.text()
raise Exception(f"Ultravox API error {response.status}: {error_text}")
join_url = (await response.json())["joinUrl"]
case OneShotInputParams():
request_body = {
"systemPrompt": self._params.system_prompt,
"temperature": self._params.temperature,
"model": self._params.model,
"voice": str(self._params.voice) if self._params.voice else None,
"metadata": self._params.metadata,
"maxDuration": f"{self._params.max_duration.total_seconds():3f}s",
"selectedTools": self._to_selected_tools(self._selected_tools)
if self._selected_tools
else [],
"medium": {
"serverWebSocket": {
"inputSampleRate": self._sample_rate,
}
},
} | self._params.extra
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.ultravox.ai/api/calls",
headers={"X-Api-Key": self._params.api_key},
json=request_body,
) as response:
if response.status != 201:
error_text = await response.text()
raise Exception(f"Ultravox API error {response.status}: {error_text}")
join_url = (await response.json())["joinUrl"]
try:
match self._params:
case JoinUrlInputParams():
join_url = self._params.join_url
case AgentInputParams():
join_url = await self._start_agent_call(self._params)
case OneShotInputParams():
join_url = await self._start_one_shot_call(self._params)
logger.info(f"Joining Ultravox Realtime call via URL: {join_url}")
self._socket = await websocket_client.connect(join_url)
self._receive_task = self.create_task(self._receive_messages())
logger.info(f"Joining Ultravox Realtime call via URL: {join_url}")
self._socket = await websocket_client.connect(join_url)
self._receive_task = self.create_task(self._receive_messages())
except Exception as e:
await self.push_error("Failed to connect to Ultravox", e, fatal=True)
async def _start_agent_call(self, params: AgentInputParams) -> str:
    """Create an Ultravox call from a pre-configured agent and return its join URL.

    Args:
        params: Agent parameters; this method reads ``agent_id``, ``api_key``,
            ``template_context``, ``metadata``, ``max_duration`` and ``extra``.

    Returns:
        The ``joinUrl`` field of the JSON response.

    Raises:
        Exception: If the API responds with any status other than 201.
    """
    request_body: Dict[str, Any] = {
        "templateContext": params.template_context,
        "metadata": params.metadata,
        "medium": {
            "serverWebSocket": {
                "inputSampleRate": self._sample_rate,
            }
        },
    }
    # Only send a duration when one is set; otherwise the key is omitted.
    if params.max_duration:
        # ".3f" renders seconds with millisecond precision. The previous
        # ":3f" spec only set a minimum field width of 3 and fell back to
        # the default six decimal places.
        request_body["maxDuration"] = f"{params.max_duration.total_seconds():.3f}s"
    # Caller-supplied extras are merged last, so they override the keys above.
    request_body = request_body | params.extra

    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"https://api.ultravox.ai/api/agents/{params.agent_id}/calls",
            headers={"X-Api-Key": params.api_key},
            json=request_body,
        ) as response:
            if response.status != 201:
                error_text = await response.text()
                raise Exception(f"Ultravox API error {response.status}: {error_text}")
            return (await response.json())["joinUrl"]
async def _start_one_shot_call(self, params: OneShotInputParams) -> str:
    """Create a one-off Ultravox call from inline settings and return its join URL.

    Args:
        params: One-shot parameters; this method reads ``system_prompt``,
            ``temperature``, ``model``, ``voice``, ``metadata``,
            ``max_duration``, ``extra`` and ``api_key``.

    Returns:
        The ``joinUrl`` field of the JSON response.

    Raises:
        Exception: If the API responds with any status other than 201.
    """
    if self._selected_tools:
        selected_tools = self._to_selected_tools(self._selected_tools)
    else:
        selected_tools = []

    request_body: Dict[str, Any] = {
        "systemPrompt": params.system_prompt,
        "temperature": params.temperature,
        "model": params.model,
        # The voice is stringified before being placed in the JSON body.
        "voice": str(params.voice) if params.voice else None,
        "metadata": params.metadata,
        # ".3f" renders seconds with millisecond precision. The previous
        # ":3f" spec only set a minimum field width of 3 and fell back to
        # the default six decimal places.
        "maxDuration": f"{params.max_duration.total_seconds():.3f}s",
        "selectedTools": selected_tools,
        "medium": {
            "serverWebSocket": {
                "inputSampleRate": self._sample_rate,
            }
        },
    } | params.extra  # extras are merged last, so they override the keys above

    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.ultravox.ai/api/calls",
            headers={"X-Api-Key": params.api_key},
            json=request_body,
        ) as response:
            if response.status != 201:
                error_text = await response.text()
                raise Exception(f"Ultravox API error {response.status}: {error_text}")
            return (await response.json())["joinUrl"]
def _to_selected_tools(self, tool: ToolsSchema) -> List[Dict[str, Any]]:
result: List[Dict[str, Any]] = []
@@ -342,7 +353,9 @@ class UltravoxRealtimeLLMService(LLMService):
socket_message = {
"type": "client_tool_result",
"invocationId": message.get("tool_call_id"),
"result": content if isinstance(content, str) else "".join(t.text for t in content),
"result": content
if isinstance(content, str)
else "".join(t.get("text") for t in content),
}
await self._send(socket_message)
@@ -350,6 +363,7 @@ class UltravoxRealtimeLLMService(LLMService):
"""Send user audio frame to Ultravox Realtime."""
if not self._socket:
return
self._last_user_id = frame.user_id if isinstance(frame, UserAudioRawFrame) else None
audio = frame.audio
if frame.sample_rate != self._sample_rate:
audio = await self._resampler.resample(audio, frame.sample_rate, self._sample_rate)
@@ -399,6 +413,8 @@ class UltravoxRealtimeLLMService(LLMService):
async def _receive_messages(self):
"""Receive messages from the Ultravox Realtime WebSocket."""
if not self._socket:
return
async for message in self._socket:
try:
if isinstance(message, bytes):
@@ -446,18 +462,16 @@ class UltravoxRealtimeLLMService(LLMService):
if not audio:
return
if not self._bot_responding:
await self.push_frame(BotStartedSpeakingFrame())
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(TTSStartedFrame())
self._bot_responding = "voice"
await self.push_frame(TTSAudioRawFrame(audio, self._sample_rate, 1))
async def _handle_response_end(self):
if self._bot_responding == "voice":
await self.push_frame(BotStoppedSpeakingFrame())
await self.push_frame(TTSStoppedFrame())
await self.push_frame(LLMFullResponseEndFrame())
self._bot_responding = False
self._bot_responding = None
async def _handle_tool_invocation(
self, tool_name: str, invocation_id: str, parameters: Dict[str, Any]
@@ -476,7 +490,7 @@ class UltravoxRealtimeLLMService(LLMService):
async def _handle_user_transcript(self, text: str):
await self.push_frame(
TranscriptionFrame(
user_id="",
user_id=self._last_user_id or "",
timestamp=time_now_iso8601(),
result=text,
text=text,
@@ -487,9 +501,10 @@ class UltravoxRealtimeLLMService(LLMService):
async def _handle_agent_transcript(
self, medium: str, text: Optional[str], delta: Optional[str], final: bool
):
frame = LLMTextFrame(text=text or delta)
frame.skip_tts = medium == "voice"
await self.push_frame(frame)
if text or delta:
frame = LLMTextFrame(text=text or delta)
frame.skip_tts = medium == "voice"
await self.push_frame(frame)
if medium == "text":
if text:
await self.push_frame(LLMFullResponseStartFrame())
@@ -498,8 +513,8 @@ class UltravoxRealtimeLLMService(LLMService):
self._bot_responding = "text"
elif final:
await self.push_frame(LLMFullResponseEndFrame())
self._bot_responding = False
else:
self._bot_responding = None
elif delta:
await self.push_frame(TTSTextFrame(text=delta, aggregated_by=AggregationType.WORD))
def create_context_aggregator(