PR comments
Also satisfy some Pyright complaints and update default model
This commit is contained in:
@@ -36,17 +36,14 @@ transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=False,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=False,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=False,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .llm import UltravoxRealtimeLLMService
|
||||
|
||||
@@ -26,8 +26,6 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.utils import create_stream_resampler
|
||||
from pipecat.frames.frames import (
|
||||
AggregationType,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
@@ -44,6 +42,7 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
UserAudioRawFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
@@ -115,7 +114,7 @@ class OneShotInputParams(BaseModel):
|
||||
api_key: str
|
||||
system_prompt: Optional[str] = None
|
||||
temperature: float = Field(default=0.0, ge=0.0, le=1.0)
|
||||
model: str = "fixie-ai/ultravox"
|
||||
model: Optional[str] = None
|
||||
voice: Optional[uuid.UUID] = None
|
||||
metadata: Dict[str, str] = Field(default_factory=dict)
|
||||
max_duration: datetime.timedelta = Field(
|
||||
@@ -177,6 +176,7 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._disconnecting = False
|
||||
self._bot_responding: Literal[None, "text", "voice"] = None
|
||||
self._last_user_id: Optional[str] = None
|
||||
|
||||
self._sample_rate = 48000
|
||||
self._resampler = create_stream_resampler()
|
||||
@@ -193,61 +193,72 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
match self._params:
|
||||
case JoinUrlInputParams():
|
||||
join_url = self._params.join_url
|
||||
case AgentInputParams():
|
||||
request_body = {
|
||||
"templateContext": self._params.template_context,
|
||||
"metadata": self._params.metadata,
|
||||
"maxDuration": f"{self._params.max_duration.total_seconds():3f}s",
|
||||
"medium": {
|
||||
"serverWebSocket": {
|
||||
"inputSampleRate": self._sample_rate,
|
||||
}
|
||||
},
|
||||
} | self._params.extra
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"https://api.ultravox.ai/api/agents/{self._params.agent_id}/calls",
|
||||
headers={"X-Api-Key": self._params.api_key},
|
||||
json=request_body,
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"Ultravox API error {response.status}: {error_text}")
|
||||
join_url = (await response.json())["joinUrl"]
|
||||
case OneShotInputParams():
|
||||
request_body = {
|
||||
"systemPrompt": self._params.system_prompt,
|
||||
"temperature": self._params.temperature,
|
||||
"model": self._params.model,
|
||||
"voice": str(self._params.voice) if self._params.voice else None,
|
||||
"metadata": self._params.metadata,
|
||||
"maxDuration": f"{self._params.max_duration.total_seconds():3f}s",
|
||||
"selectedTools": self._to_selected_tools(self._selected_tools)
|
||||
if self._selected_tools
|
||||
else [],
|
||||
"medium": {
|
||||
"serverWebSocket": {
|
||||
"inputSampleRate": self._sample_rate,
|
||||
}
|
||||
},
|
||||
} | self._params.extra
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.ultravox.ai/api/calls",
|
||||
headers={"X-Api-Key": self._params.api_key},
|
||||
json=request_body,
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"Ultravox API error {response.status}: {error_text}")
|
||||
join_url = (await response.json())["joinUrl"]
|
||||
try:
|
||||
match self._params:
|
||||
case JoinUrlInputParams():
|
||||
join_url = self._params.join_url
|
||||
case AgentInputParams():
|
||||
join_url = await self._start_agent_call(self._params)
|
||||
case OneShotInputParams():
|
||||
join_url = await self._start_one_shot_call(self._params)
|
||||
|
||||
logger.info(f"Joining Ultravox Realtime call via URL: {join_url}")
|
||||
self._socket = await websocket_client.connect(join_url)
|
||||
self._receive_task = self.create_task(self._receive_messages())
|
||||
logger.info(f"Joining Ultravox Realtime call via URL: {join_url}")
|
||||
self._socket = await websocket_client.connect(join_url)
|
||||
self._receive_task = self.create_task(self._receive_messages())
|
||||
except Exception as e:
|
||||
await self.push_error("Failed to connect to Ultravox", e, fatal=True)
|
||||
|
||||
async def _start_agent_call(self, params: AgentInputParams) -> str:
|
||||
request_body = {
|
||||
"templateContext": params.template_context,
|
||||
"metadata": params.metadata,
|
||||
"medium": {
|
||||
"serverWebSocket": {
|
||||
"inputSampleRate": self._sample_rate,
|
||||
}
|
||||
},
|
||||
}
|
||||
if params.max_duration:
|
||||
request_body["maxDuration"] = f"{params.max_duration.total_seconds():3f}s"
|
||||
request_body = request_body | params.extra
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"https://api.ultravox.ai/api/agents/{params.agent_id}/calls",
|
||||
headers={"X-Api-Key": params.api_key},
|
||||
json=request_body,
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"Ultravox API error {response.status}: {error_text}")
|
||||
return (await response.json())["joinUrl"]
|
||||
|
||||
async def _start_one_shot_call(self, params: OneShotInputParams) -> str:
|
||||
request_body = {
|
||||
"systemPrompt": params.system_prompt,
|
||||
"temperature": params.temperature,
|
||||
"model": params.model,
|
||||
"voice": str(params.voice) if params.voice else None,
|
||||
"metadata": params.metadata,
|
||||
"maxDuration": f"{params.max_duration.total_seconds():3f}s",
|
||||
"selectedTools": self._to_selected_tools(self._selected_tools)
|
||||
if self._selected_tools
|
||||
else [],
|
||||
"medium": {
|
||||
"serverWebSocket": {
|
||||
"inputSampleRate": self._sample_rate,
|
||||
}
|
||||
},
|
||||
} | params.extra
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.ultravox.ai/api/calls",
|
||||
headers={"X-Api-Key": params.api_key},
|
||||
json=request_body,
|
||||
) as response:
|
||||
if response.status != 201:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"Ultravox API error {response.status}: {error_text}")
|
||||
return (await response.json())["joinUrl"]
|
||||
|
||||
def _to_selected_tools(self, tool: ToolsSchema) -> List[Dict[str, Any]]:
|
||||
result: List[Dict[str, Any]] = []
|
||||
@@ -342,7 +353,9 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
socket_message = {
|
||||
"type": "client_tool_result",
|
||||
"invocationId": message.get("tool_call_id"),
|
||||
"result": content if isinstance(content, str) else "".join(t.text for t in content),
|
||||
"result": content
|
||||
if isinstance(content, str)
|
||||
else "".join(t.get("text") for t in content),
|
||||
}
|
||||
await self._send(socket_message)
|
||||
|
||||
@@ -350,6 +363,7 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
"""Send user audio frame to Ultravox Realtime."""
|
||||
if not self._socket:
|
||||
return
|
||||
self._last_user_id = frame.user_id if isinstance(frame, UserAudioRawFrame) else None
|
||||
audio = frame.audio
|
||||
if frame.sample_rate != self._sample_rate:
|
||||
audio = await self._resampler.resample(audio, frame.sample_rate, self._sample_rate)
|
||||
@@ -399,6 +413,8 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive messages from the Ultravox Realtime WebSocket."""
|
||||
if not self._socket:
|
||||
return
|
||||
async for message in self._socket:
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
@@ -446,18 +462,16 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
if not audio:
|
||||
return
|
||||
if not self._bot_responding:
|
||||
await self.push_frame(BotStartedSpeakingFrame())
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
self._bot_responding = "voice"
|
||||
await self.push_frame(TTSAudioRawFrame(audio, self._sample_rate, 1))
|
||||
|
||||
async def _handle_response_end(self):
|
||||
if self._bot_responding == "voice":
|
||||
await self.push_frame(BotStoppedSpeakingFrame())
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._bot_responding = False
|
||||
self._bot_responding = None
|
||||
|
||||
async def _handle_tool_invocation(
|
||||
self, tool_name: str, invocation_id: str, parameters: Dict[str, Any]
|
||||
@@ -476,7 +490,7 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
async def _handle_user_transcript(self, text: str):
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
user_id="",
|
||||
user_id=self._last_user_id or "",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=text,
|
||||
text=text,
|
||||
@@ -487,9 +501,10 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
async def _handle_agent_transcript(
|
||||
self, medium: str, text: Optional[str], delta: Optional[str], final: bool
|
||||
):
|
||||
frame = LLMTextFrame(text=text or delta)
|
||||
frame.skip_tts = medium == "voice"
|
||||
await self.push_frame(frame)
|
||||
if text or delta:
|
||||
frame = LLMTextFrame(text=text or delta)
|
||||
frame.skip_tts = medium == "voice"
|
||||
await self.push_frame(frame)
|
||||
if medium == "text":
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
@@ -498,8 +513,8 @@ class UltravoxRealtimeLLMService(LLMService):
|
||||
self._bot_responding = "text"
|
||||
elif final:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
self._bot_responding = False
|
||||
else:
|
||||
self._bot_responding = None
|
||||
elif delta:
|
||||
await self.push_frame(TTSTextFrame(text=delta, aggregated_by=AggregationType.WORD))
|
||||
|
||||
def create_context_aggregator(
|
||||
|
||||
Reference in New Issue
Block a user