Merge pull request #3940 from pipecat-ai/pk/grok-realtime-settings-pattern
Adopt the `settings` pattern for Grok Realtime session properties
This commit is contained in:
@@ -13,8 +13,9 @@ https://docs.x.ai/docs/guides/voice/agent
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import fields as dataclass_fields
|
||||
from typing import Any, Dict, Mapping, Optional, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -34,7 +35,6 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesAppendFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
@@ -56,7 +56,13 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.services.settings import (
|
||||
NOT_GIVEN,
|
||||
LLMSettings,
|
||||
_NotGiven,
|
||||
_warn_deprecated_param,
|
||||
is_given,
|
||||
)
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from . import events
|
||||
@@ -88,9 +94,96 @@ class CurrentAudioResponse:
|
||||
|
||||
@dataclass
|
||||
class GrokRealtimeLLMSettings(LLMSettings):
|
||||
"""Settings for Grok Realtime LLM services."""
|
||||
"""Settings for Grok Realtime LLM services.
|
||||
|
||||
pass
|
||||
Parameters:
|
||||
session_properties: Grok Realtime session properties (voice, audio config,
|
||||
tools, etc.). ``instructions`` is synced bidirectionally with the
|
||||
top-level ``system_instruction`` field.
|
||||
"""
|
||||
|
||||
session_properties: events.SessionProperties | _NotGiven = field(
|
||||
default_factory=lambda: NOT_GIVEN
|
||||
)
|
||||
|
||||
# -- Bidirectional sync helpers ------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _sync_top_level_to_sp(settings: "GrokRealtimeLLMSettings"):
|
||||
"""Push top-level ``system_instruction`` into ``session_properties``."""
|
||||
if not is_given(settings.session_properties):
|
||||
return
|
||||
sp = settings.session_properties
|
||||
if is_given(settings.system_instruction):
|
||||
sp.instructions = settings.system_instruction
|
||||
|
||||
# -- apply_update override -----------------------------------------------
|
||||
|
||||
def apply_update(self, delta: "GrokRealtimeLLMSettings") -> Dict[str, Any]:
|
||||
"""Merge a delta, keeping ``system_instruction`` in sync with SP.
|
||||
|
||||
When the delta contains ``session_properties``, it **replaces** the
|
||||
stored SP wholesale (matching legacy behaviour). Top-level field
|
||||
values always take precedence over conflicting SP values.
|
||||
"""
|
||||
# 1. Let the base class handle all fields including session_properties
|
||||
# (wholesale replacement when given).
|
||||
changed = super().apply_update(delta)
|
||||
|
||||
# 2. SP → top-level: if the SP was just replaced and carries
|
||||
# instructions that the delta didn't set at top level, pull it up.
|
||||
if "session_properties" in changed and is_given(self.session_properties):
|
||||
sp = self.session_properties
|
||||
if "system_instruction" not in changed and sp.instructions is not None:
|
||||
old_si = self.system_instruction
|
||||
self.system_instruction = sp.instructions
|
||||
if old_si != self.system_instruction:
|
||||
changed["system_instruction"] = old_si
|
||||
|
||||
# 3. Top-level → SP: ensure SP mirrors the authoritative top-level
|
||||
# values. Covers all cases: top-level-only delta, SP-only delta,
|
||||
# and mixed deltas where top-level takes precedence.
|
||||
self._sync_top_level_to_sp(self)
|
||||
|
||||
return changed
|
||||
|
||||
# -- from_mapping override -----------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def from_mapping(
|
||||
cls: Type["GrokRealtimeLLMSettings"], settings: Mapping[str, Any]
|
||||
) -> "GrokRealtimeLLMSettings":
|
||||
"""Build a delta from a plain dict, routing SP keys into ``session_properties``.
|
||||
|
||||
Keys that correspond to ``SessionProperties`` fields are collected into
|
||||
a nested ``session_properties`` value. ``model`` is always routed to
|
||||
the top-level field. Unknown keys go to ``extra``.
|
||||
"""
|
||||
# Determine which keys belong to our own dataclass fields.
|
||||
own_field_names = {f.name for f in dataclass_fields(cls)} - {"extra"}
|
||||
|
||||
top: Dict[str, Any] = {}
|
||||
sp_dict: Dict[str, Any] = {}
|
||||
extra: Dict[str, Any] = {}
|
||||
|
||||
sp_keys = set(events.SessionProperties.model_fields.keys())
|
||||
|
||||
for key, value in settings.items():
|
||||
# Resolve aliases first
|
||||
canonical = cls._aliases.get(key, key)
|
||||
if canonical in own_field_names:
|
||||
top[canonical] = value
|
||||
elif canonical in sp_keys:
|
||||
sp_dict[canonical] = value
|
||||
else:
|
||||
extra[key] = value
|
||||
|
||||
if sp_dict:
|
||||
top["session_properties"] = events.SessionProperties(**sp_dict)
|
||||
|
||||
instance = cls(**top)
|
||||
instance.extra = extra
|
||||
return instance
|
||||
|
||||
|
||||
class GrokRealtimeLLMService(LLMService):
|
||||
@@ -132,6 +225,11 @@ class GrokRealtimeLLMService(LLMService):
|
||||
Defaults to "wss://api.x.ai/v1/realtime".
|
||||
session_properties: Configuration properties for the realtime session.
|
||||
If None, uses default SessionProperties with voice "Ara".
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=GrokRealtimeLLMSettings(session_properties=...)``
|
||||
instead.
|
||||
|
||||
To set a different voice, configure it in session_properties:
|
||||
|
||||
session_properties = events.SessionProperties(voice="Rex")
|
||||
@@ -154,9 +252,25 @@ class GrokRealtimeLLMService(LLMService):
|
||||
seed=None,
|
||||
filter_incomplete_user_turns=False,
|
||||
user_turn_completion_config=None,
|
||||
session_properties=events.SessionProperties(),
|
||||
)
|
||||
|
||||
# 2. Apply settings delta (canonical API, always wins)
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
if session_properties is not None:
|
||||
_warn_deprecated_param(
|
||||
"session_properties",
|
||||
GrokRealtimeLLMSettings,
|
||||
"session_properties",
|
||||
)
|
||||
default_settings.session_properties = session_properties
|
||||
# Sync instructions from the deprecated SP arg to top-level
|
||||
if session_properties.instructions is not None:
|
||||
default_settings.system_instruction = session_properties.instructions
|
||||
|
||||
# Sync top-level system_instruction back into session_properties
|
||||
GrokRealtimeLLMSettings._sync_top_level_to_sp(default_settings)
|
||||
|
||||
# 3. Apply settings delta (canonical API, always wins)
|
||||
if settings is not None:
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
@@ -168,7 +282,6 @@ class GrokRealtimeLLMService(LLMService):
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self._session_properties = session_properties or events.SessionProperties()
|
||||
|
||||
self._audio_input_paused = start_audio_paused
|
||||
self._websocket = None
|
||||
@@ -217,13 +330,13 @@ class GrokRealtimeLLMService(LLMService):
|
||||
Configured sample rate or None if not manually configured.
|
||||
For PCMU/PCMA formats, returns 8000 Hz (G.711 standard).
|
||||
"""
|
||||
if not self._session_properties.audio:
|
||||
if not self._settings.session_properties.audio:
|
||||
return None
|
||||
|
||||
audio_config = (
|
||||
self._session_properties.audio.input
|
||||
self._settings.session_properties.audio.input
|
||||
if direction == "input"
|
||||
else self._session_properties.audio.output
|
||||
else self._settings.session_properties.audio.output
|
||||
)
|
||||
|
||||
if audio_config and audio_config.format:
|
||||
@@ -253,8 +366,8 @@ class GrokRealtimeLLMService(LLMService):
|
||||
|
||||
def _is_turn_detection_enabled(self) -> bool:
|
||||
"""Check if server-side VAD is enabled."""
|
||||
if self._session_properties.turn_detection:
|
||||
return self._session_properties.turn_detection.type == "server_vad"
|
||||
if self._settings.session_properties.turn_detection:
|
||||
return self._settings.session_properties.turn_detection.type == "server_vad"
|
||||
return False
|
||||
|
||||
async def _handle_interruption(self):
|
||||
@@ -321,7 +434,7 @@ class GrokRealtimeLLMService(LLMService):
|
||||
input_sample_rate: Sample rate for audio input (Hz).
|
||||
output_sample_rate: Sample rate for audio output (Hz).
|
||||
"""
|
||||
props = self._session_properties
|
||||
props = self._settings.session_properties
|
||||
if not props.audio:
|
||||
props.audio = events.AudioConfiguration()
|
||||
if not props.audio.input:
|
||||
@@ -372,21 +485,6 @@ class GrokRealtimeLLMService(LLMService):
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
# Backward-compatible dict path: frame.settings contains SessionProperties
|
||||
# fields, not our Settings fields, so we construct SessionProperties
|
||||
# directly. The frame.delta path falls through to super, which calls
|
||||
# _update_settings → our override handles the rest.
|
||||
if isinstance(frame, LLMUpdateSettingsFrame) and frame.delta is None:
|
||||
# Capture current audio config before replacing session properties.
|
||||
input_rate = self._get_configured_sample_rate("input")
|
||||
output_rate = self._get_configured_sample_rate("output")
|
||||
self._session_properties = events.SessionProperties(**frame.settings)
|
||||
if input_rate and output_rate:
|
||||
self._ensure_audio_config(input_rate, output_rate)
|
||||
await self._send_session_update()
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
@@ -485,14 +583,27 @@ class GrokRealtimeLLMService(LLMService):
|
||||
await self.push_error(error_msg=f"Error sending client event: {e}", exception=e)
|
||||
|
||||
async def _update_settings(self, delta):
|
||||
"""Apply a settings delta."""
|
||||
"""Apply a settings delta, sending a session update when needed."""
|
||||
# Capture audio config before the update — a wholesale SP replacement
|
||||
# would lose it since the new SP likely has audio=None.
|
||||
input_rate = self._get_configured_sample_rate("input")
|
||||
output_rate = self._get_configured_sample_rate("output")
|
||||
|
||||
changed = await super()._update_settings(delta)
|
||||
self._warn_unhandled_updated_settings(changed.keys())
|
||||
|
||||
# Re-establish audio config if it was lost during SP replacement.
|
||||
if "session_properties" in changed and input_rate and output_rate:
|
||||
self._ensure_audio_config(input_rate, output_rate)
|
||||
|
||||
handled = {"session_properties", "system_instruction"}
|
||||
if changed.keys() & handled:
|
||||
await self._send_session_update()
|
||||
self._warn_unhandled_updated_settings(changed.keys() - handled)
|
||||
return changed
|
||||
|
||||
async def _send_session_update(self):
|
||||
"""Update session settings on the server."""
|
||||
settings = self._session_properties
|
||||
settings = self._settings.session_properties
|
||||
adapter: GrokRealtimeLLMAdapter = self.get_llm_adapter()
|
||||
|
||||
if self._context:
|
||||
|
||||
@@ -236,7 +236,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
api_key: OpenAI API key for authentication.
|
||||
model: OpenAI model name.
|
||||
|
||||
.. deprecated::
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=OpenAIRealtimeLLMSettings(model=...)`` instead.
|
||||
|
||||
This is a connection-level parameter set via the WebSocket URL query
|
||||
@@ -246,7 +246,7 @@ class OpenAIRealtimeLLMService(LLMService):
|
||||
session_properties: Configuration properties for the realtime session.
|
||||
If None, uses default SessionProperties.
|
||||
|
||||
.. deprecated::
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=OpenAIRealtimeLLMSettings(session_properties=...)``
|
||||
instead.
|
||||
settings: Runtime-updatable settings for this service.
|
||||
|
||||
@@ -12,6 +12,8 @@ import pytest
|
||||
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService, DeepgramSTTSettings
|
||||
from pipecat.services.deepgram.stt_sagemaker import DeepgramSageMakerSTTSettings
|
||||
from pipecat.services.grok.realtime import events as grok_events
|
||||
from pipecat.services.grok.realtime.llm import GrokRealtimeLLMSettings
|
||||
from pipecat.services.openai.realtime import events
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMSettings
|
||||
from pipecat.services.settings import (
|
||||
@@ -815,3 +817,169 @@ class TestOpenAIRealtimeSettingsFromMapping:
|
||||
assert store.session_properties.instructions == "Be concise."
|
||||
assert store.session_properties.output_modalities == ["text"]
|
||||
assert store.system_instruction == "Be concise."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GrokRealtimeLLMSettings: apply_update
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGrokRealtimeSettingsApplyUpdate:
|
||||
def _make_store(self, **kwargs) -> GrokRealtimeLLMSettings:
|
||||
"""Helper to build a store-mode GrokRealtimeLLMSettings."""
|
||||
defaults = dict(
|
||||
model=None,
|
||||
system_instruction=None,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
seed=None,
|
||||
filter_incomplete_user_turns=False,
|
||||
user_turn_completion_config=None,
|
||||
session_properties=grok_events.SessionProperties(),
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return GrokRealtimeLLMSettings(**defaults)
|
||||
|
||||
def test_top_level_system_instruction_syncs_to_sp(self):
|
||||
"""Updating top-level system_instruction should propagate to session_properties.instructions."""
|
||||
store = self._make_store()
|
||||
delta = GrokRealtimeLLMSettings(system_instruction="Be helpful.")
|
||||
changed = store.apply_update(delta)
|
||||
|
||||
assert "system_instruction" in changed
|
||||
assert store.system_instruction == "Be helpful."
|
||||
assert store.session_properties.instructions == "Be helpful."
|
||||
|
||||
def test_sp_replaces_wholesale(self):
|
||||
"""session_properties in delta replaces the entire stored SP."""
|
||||
store = self._make_store(
|
||||
session_properties=grok_events.SessionProperties(
|
||||
voice="Rex",
|
||||
instructions="Old instructions.",
|
||||
),
|
||||
system_instruction="Old instructions.",
|
||||
)
|
||||
|
||||
new_sp = grok_events.SessionProperties(voice="Sal")
|
||||
delta = GrokRealtimeLLMSettings(session_properties=new_sp)
|
||||
changed = store.apply_update(delta)
|
||||
|
||||
assert "session_properties" in changed
|
||||
assert store.session_properties.voice == "Sal"
|
||||
# instructions is synced from top-level system_instruction
|
||||
assert store.session_properties.instructions == "Old instructions."
|
||||
|
||||
def test_sp_instructions_syncs_to_top_level(self):
|
||||
"""session_properties.instructions should sync to top-level system_instruction."""
|
||||
store = self._make_store()
|
||||
new_sp = grok_events.SessionProperties(instructions="New instructions.")
|
||||
delta = GrokRealtimeLLMSettings(session_properties=new_sp)
|
||||
changed = store.apply_update(delta)
|
||||
|
||||
assert "system_instruction" in changed
|
||||
assert store.system_instruction == "New instructions."
|
||||
assert store.session_properties.instructions == "New instructions."
|
||||
|
||||
def test_top_level_si_takes_precedence_over_sp_instructions(self):
|
||||
"""When both system_instruction and SP.instructions are in delta, top-level wins."""
|
||||
store = self._make_store()
|
||||
new_sp = grok_events.SessionProperties(instructions="sp instructions")
|
||||
delta = GrokRealtimeLLMSettings(
|
||||
system_instruction="top instructions",
|
||||
session_properties=new_sp,
|
||||
)
|
||||
store.apply_update(delta)
|
||||
|
||||
assert store.system_instruction == "top instructions"
|
||||
assert store.session_properties.instructions == "top instructions"
|
||||
|
||||
def test_non_synced_field_update_does_not_affect_sp(self):
|
||||
"""Updating a non-synced field like temperature shouldn't touch session_properties."""
|
||||
store = self._make_store(
|
||||
session_properties=grok_events.SessionProperties(instructions="Keep me."),
|
||||
system_instruction="Keep me.",
|
||||
)
|
||||
original_sp = store.session_properties
|
||||
|
||||
delta = GrokRealtimeLLMSettings(temperature=0.5)
|
||||
changed = store.apply_update(delta)
|
||||
|
||||
assert "temperature" in changed
|
||||
assert store.temperature == 0.5
|
||||
# SP should be untouched (same object)
|
||||
assert store.session_properties is original_sp
|
||||
assert store.session_properties.instructions == "Keep me."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GrokRealtimeLLMSettings: from_mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGrokRealtimeSettingsFromMapping:
|
||||
def test_sp_keys_route_to_session_properties(self):
|
||||
"""SessionProperties fields (instructions, voice, etc.) route into nested SP."""
|
||||
delta = GrokRealtimeLLMSettings.from_mapping(
|
||||
{"instructions": "Be concise.", "voice": "Rex"}
|
||||
)
|
||||
assert is_given(delta.session_properties)
|
||||
assert delta.session_properties.instructions == "Be concise."
|
||||
assert delta.session_properties.voice == "Rex"
|
||||
|
||||
def test_model_routes_to_top_level(self):
|
||||
"""model should go to the top-level field, not session_properties."""
|
||||
delta = GrokRealtimeLLMSettings.from_mapping({"model": "some-model"})
|
||||
assert delta.model == "some-model"
|
||||
# No session_properties should be created since no SP keys were present
|
||||
assert not is_given(delta.session_properties)
|
||||
|
||||
def test_unknown_keys_go_to_extra(self):
|
||||
"""Unrecognized keys should land in extra."""
|
||||
delta = GrokRealtimeLLMSettings.from_mapping({"unknown_param": 42})
|
||||
assert not is_given(delta.model)
|
||||
assert not is_given(delta.session_properties)
|
||||
assert delta.extra == {"unknown_param": 42}
|
||||
|
||||
def test_mixed_keys(self):
|
||||
"""model + SP keys + unknown keys are routed correctly."""
|
||||
delta = GrokRealtimeLLMSettings.from_mapping(
|
||||
{
|
||||
"model": "some-model",
|
||||
"instructions": "Be helpful.",
|
||||
"unknown": "val",
|
||||
}
|
||||
)
|
||||
assert delta.model == "some-model"
|
||||
assert is_given(delta.session_properties)
|
||||
assert delta.session_properties.instructions == "Be helpful."
|
||||
assert delta.extra == {"unknown": "val"}
|
||||
|
||||
def test_roundtrip_from_mapping_apply_update(self):
|
||||
"""Simulate dict-style update: from_mapping -> apply_update."""
|
||||
store = GrokRealtimeLLMSettings(
|
||||
model=None,
|
||||
system_instruction=None,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
frequency_penalty=None,
|
||||
presence_penalty=None,
|
||||
seed=None,
|
||||
filter_incomplete_user_turns=False,
|
||||
user_turn_completion_config=None,
|
||||
session_properties=grok_events.SessionProperties(),
|
||||
)
|
||||
|
||||
raw = {"instructions": "Be concise.", "voice": "Eve"}
|
||||
delta = GrokRealtimeLLMSettings.from_mapping(raw)
|
||||
changed = store.apply_update(delta)
|
||||
|
||||
assert "session_properties" in changed
|
||||
assert store.session_properties.instructions == "Be concise."
|
||||
assert store.session_properties.voice == "Eve"
|
||||
assert store.system_instruction == "Be concise."
|
||||
|
||||
Reference in New Issue
Block a user