Improve sample_rate handling in GrokRealtimeLLMService
This commit is contained in:
@@ -25,7 +25,6 @@ Usage:
|
||||
python 50-grok-realtime.py --transport daily
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
@@ -37,7 +36,7 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
# Note: Grok has built-in server-side VAD, so we don't need local VAD
|
||||
# from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame, TranscriptionMessage
|
||||
from pipecat.frames.frames import LLMRunFrame, TranscriptionMessage
|
||||
from pipecat.observers.loggers.transcription_log_observer import (
|
||||
TranscriptionLogObserver,
|
||||
)
|
||||
|
||||
@@ -27,14 +27,12 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
@@ -57,7 +55,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM, LLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from . import events
|
||||
@@ -114,7 +111,6 @@ class GrokRealtimeLLMService(LLMService):
|
||||
base_url: str = "wss://api.x.ai/v1/realtime",
|
||||
session_properties: Optional[events.SessionProperties] = None,
|
||||
start_audio_paused: bool = False,
|
||||
sample_rate: int = 24000,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Grok Realtime Voice Agent LLM service.
|
||||
@@ -128,18 +124,15 @@ class GrokRealtimeLLMService(LLMService):
|
||||
session_properties: Configuration properties for the realtime session.
|
||||
If None, uses default SessionProperties with the specified voice.
|
||||
start_audio_paused: Whether to start with audio input paused. Defaults to False.
|
||||
sample_rate: Audio sample rate in Hz. Supported: 8000, 16000, 21050, 24000,
|
||||
32000, 44100, 48000. Defaults to 24000.
|
||||
**kwargs: Additional arguments passed to parent LLMService.
|
||||
"""
|
||||
super().__init__(base_url=base_url, **kwargs)
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self._sample_rate = sample_rate
|
||||
self._voice = voice
|
||||
|
||||
# Initialize session_properties with voice and audio config
|
||||
# Initialize session_properties
|
||||
if session_properties:
|
||||
self._session_properties = session_properties
|
||||
# Ensure voice is set
|
||||
@@ -149,10 +142,7 @@ class GrokRealtimeLLMService(LLMService):
|
||||
self._session_properties = events.SessionProperties(
|
||||
voice=voice,
|
||||
turn_detection=events.TurnDetection(type="server_vad"),
|
||||
audio=events.AudioConfiguration(
|
||||
input=events.AudioInput(format=events.PCMAudioFormat(rate=sample_rate)),
|
||||
output=events.AudioOutput(format=events.PCMAudioFormat(rate=sample_rate)),
|
||||
),
|
||||
# Audio config will be set in start() based on PipelineParams
|
||||
)
|
||||
|
||||
self._audio_input_paused = start_audio_paused
|
||||
@@ -192,6 +182,50 @@ class GrokRealtimeLLMService(LLMService):
|
||||
"""
|
||||
self._audio_input_paused = paused
|
||||
|
||||
def _get_configured_sample_rate(self, direction: str) -> Optional[int]:
|
||||
"""Get manually configured sample rate for input or output.
|
||||
|
||||
Args:
|
||||
direction: Either "input" or "output".
|
||||
|
||||
Returns:
|
||||
Configured sample rate or None if not manually configured.
|
||||
For PCMU/PCMA formats, returns 8000 Hz (G.711 standard).
|
||||
"""
|
||||
if not self._session_properties.audio:
|
||||
return None
|
||||
|
||||
audio_config = (
|
||||
self._session_properties.audio.input
|
||||
if direction == "input"
|
||||
else self._session_properties.audio.output
|
||||
)
|
||||
|
||||
if audio_config and audio_config.format:
|
||||
# PCM format has configurable rate
|
||||
if hasattr(audio_config.format, "rate"):
|
||||
return audio_config.format.rate
|
||||
# PCMU/PCMA formats are fixed at 8000 Hz (G.711 standard)
|
||||
elif audio_config.format.type in ("audio/pcmu", "audio/pcma"):
|
||||
return 8000
|
||||
|
||||
return None
|
||||
|
||||
def _get_output_sample_rate(self) -> int:
|
||||
"""Get the output sample rate from session properties.
|
||||
|
||||
Returns:
|
||||
Output sample rate in Hz.
|
||||
|
||||
Note:
|
||||
This assumes start() has been called, which guarantees
|
||||
session_properties.audio.output exists.
|
||||
"""
|
||||
rate = self._get_configured_sample_rate("output")
|
||||
if rate is None:
|
||||
raise RuntimeError("Output sample rate not configured.")
|
||||
return rate
|
||||
|
||||
def _is_turn_detection_enabled(self) -> bool:
|
||||
"""Check if server-side VAD is enabled."""
|
||||
if self._session_properties.turn_detection:
|
||||
@@ -230,7 +264,7 @@ class GrokRealtimeLLMService(LLMService):
|
||||
) -> int:
|
||||
"""Calculate audio duration in milliseconds based on PCM audio parameters."""
|
||||
if sample_rate is None:
|
||||
sample_rate = self._sample_rate
|
||||
sample_rate = self._get_output_sample_rate()
|
||||
samples = total_bytes / bytes_per_sample
|
||||
duration_seconds = samples / sample_rate
|
||||
return int(duration_seconds * 1000)
|
||||
@@ -260,6 +294,23 @@ class GrokRealtimeLLMService(LLMService):
|
||||
frame: The start frame triggering service initialization.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
# Ensure audio configuration exists with both input and output
|
||||
if not self._session_properties.audio:
|
||||
self._session_properties.audio = events.AudioConfiguration()
|
||||
|
||||
# Fill in missing input configuration
|
||||
if not self._session_properties.audio.input:
|
||||
self._session_properties.audio.input = events.AudioInput(
|
||||
format=events.PCMAudioFormat(rate=frame.audio_in_sample_rate)
|
||||
)
|
||||
|
||||
# Fill in missing output configuration
|
||||
if not self._session_properties.audio.output:
|
||||
self._session_properties.audio.output = events.AudioOutput(
|
||||
format=events.PCMAudioFormat(rate=frame.audio_out_sample_rate)
|
||||
)
|
||||
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -501,7 +552,7 @@ class GrokRealtimeLLMService(LLMService):
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=audio,
|
||||
sample_rate=self._sample_rate,
|
||||
sample_rate=self._get_output_sample_rate(),
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
|
||||
43
uv.lock
generated
43
uv.lock
generated
@@ -612,11 +612,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2025.8.3"
|
||||
version = "2025.11.12"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/dc/67/960ebe6bf230a96cda2e0abcf73af550ec4f090005363542f0765df162e0/certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407", size = 162386, upload-time = "2025-08-03T03:07:47.08Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/8c/58f469717fa48465e4a50c014a0400602d3c437d7c0c468e17ada824da3a/certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316", size = 160538, upload-time = "2025-11-12T02:54:51.517Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216, upload-time = "2025-08-03T03:07:45.777Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b", size = 159438, upload-time = "2025-11-12T02:54:49.735Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4044,7 +4044,7 @@ soundfile = [
|
||||
{ name = "soundfile" },
|
||||
]
|
||||
speechmatics = [
|
||||
{ name = "speechmatics-rt" },
|
||||
{ name = "speechmatics-voice", extra = ["smart"] },
|
||||
]
|
||||
strands = [
|
||||
{ name = "strands-agents" },
|
||||
@@ -4192,7 +4192,7 @@ requires-dist = [
|
||||
{ name = "simli-ai", marker = "extra == 'simli'", specifier = "~=1.0.3" },
|
||||
{ name = "soundfile", marker = "extra == 'soundfile'", specifier = "~=0.13.1" },
|
||||
{ name = "soxr", specifier = "~=0.5.0" },
|
||||
{ name = "speechmatics-rt", marker = "extra == 'speechmatics'", specifier = ">=0.5.0" },
|
||||
{ name = "speechmatics-voice", extras = ["smart"], marker = "extra == 'speechmatics'", specifier = ">=0.2.4" },
|
||||
{ name = "strands-agents", marker = "extra == 'strands'", specifier = ">=1.9.1,<2" },
|
||||
{ name = "tenacity", marker = "extra == 'livekit'", specifier = ">=8.2.3,<10.0.0" },
|
||||
{ name = "timm", marker = "extra == 'moondream'", specifier = "~=1.0.13" },
|
||||
@@ -5917,14 +5917,35 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "speechmatics-rt"
|
||||
version = "0.5.0"
|
||||
version = "0.5.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "websockets" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/57/26/10359e1f16c2aa6a198eb11a9056f4a86a8bb8d4e610bbbe4a118b227b59/speechmatics_rt-0.5.0.tar.gz", hash = "sha256:ca974a186a012f946fd997deeaf3bf1c4f203f6d6e05a866172d27709183afc8", size = 26832, upload-time = "2025-10-15T15:54:25.695Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c0/a3/bb4d063a4405744951066c45ffbf7cd714a6fc00a20ef0cc83fe2494ed79/speechmatics_rt-0.5.3.tar.gz", hash = "sha256:c98d21041e5a0c90a66e463c3d5b98879c17eac0bbebb4100fd9d0f2b330bb19", size = 27333, upload-time = "2025-12-16T19:20:50.199Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/47/2e/9931ebe9360e9d385c68826b33137c2c9a4cfa361cd929d1ac6e72ebfe53/speechmatics_rt-0.5.0-py3-none-any.whl", hash = "sha256:58151488f891fa00cf7054f0cfab1b1eb94b55c3441be587f7941c726caef991", size = 32850, upload-time = "2025-10-15T15:54:24.5Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/5a/35dd924f9bfeb1604e01806ad0e16a9c596f3c44d13e66794f10d10f828b/speechmatics_rt-0.5.3-py3-none-any.whl", hash = "sha256:12f97f19bb989852b8ff3c6d1e28f4f0ea6fd9356e19da75d0e9877545931ce6", size = 33365, upload-time = "2025-12-16T19:20:49.031Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "speechmatics-voice"
|
||||
version = "0.2.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "speechmatics-rt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b2/f9/9d81e4abe9ae1c8745372eaf43523213b0333e9721699fb0f3d3bff6c17e/speechmatics_voice-0.2.4.tar.gz", hash = "sha256:e3b5c7a8c24fa7d555b80a72ab181797665c74944400468ca5fb7e54b5f9eae6", size = 60852, upload-time = "2025-12-17T23:22:13.437Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/69/a6/401dba9be6be914e57b7814360ba0bece55f24140bb7d5c3dc5f07bcd77f/speechmatics_voice-0.2.4-py3-none-any.whl", hash = "sha256:71d0f5272c2db1221422ab19b6c898ea7b38f9fb7f523904f54a4d8c3e4cef12", size = 57056, upload-time = "2025-12-17T23:22:11.837Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
smart = [
|
||||
{ name = "certifi" },
|
||||
{ name = "onnxruntime" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6508,7 +6529,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "4.56.2"
|
||||
version = "4.57.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock" },
|
||||
@@ -6522,9 +6543,9 @@ dependencies = [
|
||||
{ name = "tokenizers" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e5/82/0bcfddd134cdf53440becb5e738257cc3cf34cf229d63b57bfd288e6579f/transformers-4.56.2.tar.gz", hash = "sha256:5e7c623e2d7494105c726dd10f6f90c2c99a55ebe86eef7233765abd0cb1c529", size = 9844296, upload-time = "2025-09-19T15:16:26.778Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/dd/70/d42a739e8dfde3d92bb2fff5819cbf331fe9657323221e79415cd5eb65ee/transformers-4.57.3.tar.gz", hash = "sha256:df4945029aaddd7c09eec5cad851f30662f8bd1746721b34cc031d70c65afebc", size = 10139680, upload-time = "2025-11-25T15:51:30.139Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/70/26/2591b48412bde75e33bfd292034103ffe41743cacd03120e3242516cd143/transformers-4.56.2-py3-none-any.whl", hash = "sha256:79c03d0e85b26cb573c109ff9eafa96f3c8d4febfd8a0774e8bba32702dd6dde", size = 11608055, upload-time = "2025-09-19T15:16:23.736Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/6b/2f416568b3c4c91c96e5a365d164f8a4a4a88030aa8ab4644181fdadce97/transformers-4.57.3-py3-none-any.whl", hash = "sha256:c77d353a4851b1880191603d36acb313411d3577f6e2897814f333841f7003f4", size = 11993463, upload-time = "2025-11-25T15:51:26.493Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user