diff --git a/changelog/4426.added.md b/changelog/4426.added.md new file mode 100644 index 000000000..86488dccc --- /dev/null +++ b/changelog/4426.added.md @@ -0,0 +1 @@ +- Added `keyterms` support to ElevenLabs STT services so Scribe V2 callers can bias transcription for both file-based and realtime transcription. diff --git a/src/pipecat/services/elevenlabs/stt.py b/src/pipecat/services/elevenlabs/stt.py index 6fed6c5e5..315cf5af6 100644 --- a/src/pipecat/services/elevenlabs/stt.py +++ b/src/pipecat/services/elevenlabs/stt.py @@ -19,6 +19,7 @@ from collections.abc import AsyncGenerator from dataclasses import dataclass, field from enum import StrEnum from typing import Any +from urllib.parse import urlencode import aiohttp from loguru import logger @@ -36,7 +37,7 @@ from pipecat.frames.frames import ( VADUserStoppedSpeakingFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven +from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven, is_given from pipecat.services.stt_latency import ELEVENLABS_REALTIME_TTFS_P99, ELEVENLABS_TTFS_P99 from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService from pipecat.transcriptions.language import Language, resolve_language @@ -187,9 +188,11 @@ class ElevenLabsSTTSettings(STTSettings): Parameters: tag_audio_events: Whether to include audio events like (laughter), (coughing) in the transcription. + keyterms: List of key terms or phrases to bias transcription towards. """ tag_audio_events: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) + keyterms: list[str] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) @dataclass @@ -199,12 +202,14 @@ class ElevenLabsRealtimeSTTSettings(STTSettings): See ``ElevenLabsRealtimeSTTService.InputParams`` for detailed descriptions. Parameters: + keyterms: List of key terms or phrases to bias transcription towards. vad_silence_threshold_secs: Seconds of silence before VAD commits (0.3-3.0). vad_threshold: VAD sensitivity (0.1-0.9, lower is more sensitive). min_speech_duration_ms: Minimum speech duration for VAD (50-2000ms). min_silence_duration_ms: Minimum silence duration for VAD (50-2000ms). """ + keyterms: list[str] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) vad_silence_threshold_secs: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) vad_threshold: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) min_speech_duration_ms: int | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) @@ -277,6 +282,7 @@ class ElevenLabsSTTService(SegmentedSTTService): model="scribe_v2", language=Language.EN, tag_audio_events=None, + keyterms=None, ) # 2. Apply direct init arg overrides (deprecated) @@ -355,6 +361,10 @@ class ElevenLabsSTTService(SegmentedSTTService): data.add_field("language_code", self._settings.language) if self._settings.tag_audio_events is not None: data.add_field("tag_audio_events", str(self._settings.tag_audio_events).lower()) + keyterms = self._settings.keyterms + if is_given(keyterms) and keyterms is not None: + for keyterm in keyterms: + data.add_field("keyterms", keyterm) async with self._session.post(url, data=data, headers=headers) as response: if response.status != 200: @@ -539,6 +549,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService): vad_threshold=None, min_speech_duration_ms=None, min_silence_duration_ms=None, + keyterms=None, ) # 2. Apply direct init arg overrides (deprecated) @@ -771,6 +782,11 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService): params.append(f"audio_format={self._audio_format}") params.append(f"commit_strategy={self._commit_strategy.value}") + keyterms = self._settings.keyterms + if is_given(keyterms) and keyterms is not None: + for keyterm in keyterms: + params.append(urlencode({"keyterms": keyterm})) + # Add optional parameters if self._include_timestamps: params.append(f"include_timestamps={str(self._include_timestamps).lower()}") diff --git a/tests/test_elevenlabs_stt.py b/tests/test_elevenlabs_stt.py new file mode 100644 index 000000000..10c937509 --- /dev/null +++ b/tests/test_elevenlabs_stt.py @@ -0,0 +1,108 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from urllib.parse import parse_qs, urlparse + +import aiohttp +import pytest +from aiohttp import web + +from pipecat.services.elevenlabs.stt import ( + CommitStrategy, + ElevenLabsRealtimeSTTService, + ElevenLabsSTTService, + audio_format_from_sample_rate, +) +from pipecat.transcriptions.language import Language + + +@pytest.mark.asyncio +async def test_elevenlabs_stt_sends_keyterms_multipart_fields(aiohttp_client): + captured = {"headers": {}, "fields": []} + + async def handler(request): + captured["headers"]["xi-api-key"] = request.headers.get("xi-api-key") + reader = await request.multipart() + + async for part in reader: + if part.name == "file": + await part.read() + else: + captured["fields"].append((part.name, await part.text())) + + return web.json_response({"text": "hello", "language_code": "eng", "words": []}) + + app = web.Application() + app.router.add_post("/v1/speech-to-text", handler) + client = await aiohttp_client(app) + base_url = str(client.make_url("/")).rstrip("/") + + async with aiohttp.ClientSession() as session: + service = ElevenLabsSTTService( + api_key="test-key", + aiohttp_session=session, + base_url=base_url, + settings=ElevenLabsSTTService.Settings( + language=Language.EN, + keyterms=["Pipecat", "Scribe V2"], + ), + ) + + result = await service._transcribe_audio(b"RIFF") + + assert result["text"] == "hello" + assert captured["headers"]["xi-api-key"] == "test-key" + assert ("model_id", "scribe_v2") in captured["fields"] + assert ("language_code", "eng") in captured["fields"] + assert [value for name, value in captured["fields"] if name == "keyterms"] == [ + "Pipecat", + "Scribe V2", + ] + + +@pytest.mark.asyncio +async def test_elevenlabs_realtime_websocket_url_includes_keyterms(monkeypatch): + captured = {} + + async def fake_websocket_connect(url, *, additional_headers): + captured["url"] = url + captured["headers"] = additional_headers + return object() + + monkeypatch.setattr( + "pipecat.services.elevenlabs.stt.websocket_connect", + fake_websocket_connect, + ) + + service = ElevenLabsRealtimeSTTService( + api_key="test-key", + base_url="example.test", + commit_strategy=CommitStrategy.VAD, + sample_rate=16000, + include_timestamps=True, + settings=ElevenLabsRealtimeSTTService.Settings( + language=Language.EN, + keyterms=["Pipecat", "Scribe V2"], + vad_threshold=0.7, + ), + ) + service._audio_format = audio_format_from_sample_rate(16000) + + await service._connect_websocket() + + parsed = urlparse(captured["url"]) + query = parse_qs(parsed.query) + assert parsed.scheme == "wss" + assert parsed.netloc == "example.test" + assert parsed.path == "/v1/speech-to-text/realtime" + assert query["model_id"] == ["scribe_v2_realtime"] + assert query["language_code"] == ["en"] + assert query["audio_format"] == ["pcm_16000"] + assert query["commit_strategy"] == ["vad"] + assert query["include_timestamps"] == ["true"] + assert query["vad_threshold"] == ["0.7"] + assert query["keyterms"] == ["Pipecat", "Scribe V2"] + assert captured["headers"] == {"xi-api-key": "test-key"}