Add ElevenLabs STT keyterms support

This commit is contained in:
Marcelo Díaz
2026-05-06 01:30:49 +00:00
parent fa31a2fd63
commit edfcd6948b
3 changed files with 126 additions and 1 deletions

1
changelog/4426.added.md Normal file
View File

@@ -0,0 +1 @@
- Added `keyterms` support to ElevenLabs STT services so Scribe V2 callers can bias transcription for both file-based and realtime transcription.

View File

@@ -19,6 +19,7 @@ from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any
from urllib.parse import urlencode
import aiohttp
from loguru import logger
@@ -36,7 +37,7 @@ from pipecat.frames.frames import (
VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven
from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven, is_given
from pipecat.services.stt_latency import ELEVENLABS_REALTIME_TTFS_P99, ELEVENLABS_TTFS_P99
from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService
from pipecat.transcriptions.language import Language, resolve_language
@@ -187,9 +188,11 @@ class ElevenLabsSTTSettings(STTSettings):
Parameters:
tag_audio_events: Whether to include audio events like (laughter),
(coughing) in the transcription.
keyterms: List of key terms or phrases to bias transcription towards.
"""
tag_audio_events: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
keyterms: list[str] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
@dataclass
@@ -199,12 +202,14 @@ class ElevenLabsRealtimeSTTSettings(STTSettings):
See ``ElevenLabsRealtimeSTTService.InputParams`` for detailed descriptions.
Parameters:
keyterms: List of key terms or phrases to bias transcription towards.
vad_silence_threshold_secs: Seconds of silence before VAD commits (0.3-3.0).
vad_threshold: VAD sensitivity (0.1-0.9, lower is more sensitive).
min_speech_duration_ms: Minimum speech duration for VAD (50-2000ms).
min_silence_duration_ms: Minimum silence duration for VAD (50-2000ms).
"""
keyterms: list[str] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
vad_silence_threshold_secs: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
vad_threshold: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
min_speech_duration_ms: int | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
@@ -277,6 +282,7 @@ class ElevenLabsSTTService(SegmentedSTTService):
model="scribe_v2",
language=Language.EN,
tag_audio_events=None,
keyterms=None,
)
# 2. Apply direct init arg overrides (deprecated)
@@ -355,6 +361,10 @@ class ElevenLabsSTTService(SegmentedSTTService):
data.add_field("language_code", self._settings.language)
if self._settings.tag_audio_events is not None:
data.add_field("tag_audio_events", str(self._settings.tag_audio_events).lower())
keyterms = self._settings.keyterms
if is_given(keyterms) and keyterms is not None:
for keyterm in keyterms:
data.add_field("keyterms", keyterm)
async with self._session.post(url, data=data, headers=headers) as response:
if response.status != 200:
@@ -539,6 +549,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
vad_threshold=None,
min_speech_duration_ms=None,
min_silence_duration_ms=None,
keyterms=None,
)
# 2. Apply direct init arg overrides (deprecated)
@@ -771,6 +782,11 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
params.append(f"audio_format={self._audio_format}")
params.append(f"commit_strategy={self._commit_strategy.value}")
keyterms = self._settings.keyterms
if is_given(keyterms) and keyterms is not None:
for keyterm in keyterms:
params.append(urlencode({"keyterms": keyterm}))
# Add optional parameters
if self._include_timestamps:
params.append(f"include_timestamps={str(self._include_timestamps).lower()}")

View File

@@ -0,0 +1,108 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from urllib.parse import parse_qs, urlparse
import aiohttp
import pytest
from aiohttp import web
from pipecat.services.elevenlabs.stt import (
CommitStrategy,
ElevenLabsRealtimeSTTService,
ElevenLabsSTTService,
audio_format_from_sample_rate,
)
from pipecat.transcriptions.language import Language
@pytest.mark.asyncio
async def test_elevenlabs_stt_sends_keyterms_multipart_fields(aiohttp_client):
captured = {"headers": {}, "fields": []}
async def handler(request):
captured["headers"]["xi-api-key"] = request.headers.get("xi-api-key")
reader = await request.multipart()
async for part in reader:
if part.name == "file":
await part.read()
else:
captured["fields"].append((part.name, await part.text()))
return web.json_response({"text": "hello", "language_code": "eng", "words": []})
app = web.Application()
app.router.add_post("/v1/speech-to-text", handler)
client = await aiohttp_client(app)
base_url = str(client.make_url("/")).rstrip("/")
async with aiohttp.ClientSession() as session:
service = ElevenLabsSTTService(
api_key="test-key",
aiohttp_session=session,
base_url=base_url,
settings=ElevenLabsSTTService.Settings(
language=Language.EN,
keyterms=["Pipecat", "Scribe V2"],
),
)
result = await service._transcribe_audio(b"RIFF")
assert result["text"] == "hello"
assert captured["headers"]["xi-api-key"] == "test-key"
assert ("model_id", "scribe_v2") in captured["fields"]
assert ("language_code", "eng") in captured["fields"]
assert [value for name, value in captured["fields"] if name == "keyterms"] == [
"Pipecat",
"Scribe V2",
]
@pytest.mark.asyncio
async def test_elevenlabs_realtime_websocket_url_includes_keyterms(monkeypatch):
captured = {}
async def fake_websocket_connect(url, *, additional_headers):
captured["url"] = url
captured["headers"] = additional_headers
return object()
monkeypatch.setattr(
"pipecat.services.elevenlabs.stt.websocket_connect",
fake_websocket_connect,
)
service = ElevenLabsRealtimeSTTService(
api_key="test-key",
base_url="example.test",
commit_strategy=CommitStrategy.VAD,
sample_rate=16000,
include_timestamps=True,
settings=ElevenLabsRealtimeSTTService.Settings(
language=Language.EN,
keyterms=["Pipecat", "Scribe V2"],
vad_threshold=0.7,
),
)
service._audio_format = audio_format_from_sample_rate(16000)
await service._connect_websocket()
parsed = urlparse(captured["url"])
query = parse_qs(parsed.query)
assert parsed.scheme == "wss"
assert parsed.netloc == "example.test"
assert parsed.path == "/v1/speech-to-text/realtime"
assert query["model_id"] == ["scribe_v2_realtime"]
assert query["language_code"] == ["en"]
assert query["audio_format"] == ["pcm_16000"]
assert query["commit_strategy"] == ["vad"]
assert query["include_timestamps"] == ["true"]
assert query["vad_threshold"] == ["0.7"]
assert query["keyterms"] == ["Pipecat", "Scribe V2"]
assert captured["headers"] == {"xi-api-key": "test-key"}