Add ElevenLabs STT keyterms support
This commit is contained in:
1
changelog/4426.added.md
Normal file
1
changelog/4426.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `keyterms` support to ElevenLabs STT services so Scribe V2 callers can bias transcription for both file-based and realtime transcription.
|
||||
@@ -19,6 +19,7 @@ from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
@@ -36,7 +37,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven
|
||||
from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven, is_given
|
||||
from pipecat.services.stt_latency import ELEVENLABS_REALTIME_TTFS_P99, ELEVENLABS_TTFS_P99
|
||||
from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -187,9 +188,11 @@ class ElevenLabsSTTSettings(STTSettings):
|
||||
Parameters:
|
||||
tag_audio_events: Whether to include audio events like (laughter),
|
||||
(coughing) in the transcription.
|
||||
keyterms: List of key terms or phrases to bias transcription towards.
|
||||
"""
|
||||
|
||||
tag_audio_events: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
keyterms: list[str] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -199,12 +202,14 @@ class ElevenLabsRealtimeSTTSettings(STTSettings):
|
||||
See ``ElevenLabsRealtimeSTTService.InputParams`` for detailed descriptions.
|
||||
|
||||
Parameters:
|
||||
keyterms: List of key terms or phrases to bias transcription towards.
|
||||
vad_silence_threshold_secs: Seconds of silence before VAD commits (0.3-3.0).
|
||||
vad_threshold: VAD sensitivity (0.1-0.9, lower is more sensitive).
|
||||
min_speech_duration_ms: Minimum speech duration for VAD (50-2000ms).
|
||||
min_silence_duration_ms: Minimum silence duration for VAD (50-2000ms).
|
||||
"""
|
||||
|
||||
keyterms: list[str] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
vad_silence_threshold_secs: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
vad_threshold: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
min_speech_duration_ms: int | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
@@ -277,6 +282,7 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
model="scribe_v2",
|
||||
language=Language.EN,
|
||||
tag_audio_events=None,
|
||||
keyterms=None,
|
||||
)
|
||||
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
@@ -355,6 +361,10 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
data.add_field("language_code", self._settings.language)
|
||||
if self._settings.tag_audio_events is not None:
|
||||
data.add_field("tag_audio_events", str(self._settings.tag_audio_events).lower())
|
||||
keyterms = self._settings.keyterms
|
||||
if is_given(keyterms) and keyterms is not None:
|
||||
for keyterm in keyterms:
|
||||
data.add_field("keyterms", keyterm)
|
||||
|
||||
async with self._session.post(url, data=data, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
@@ -539,6 +549,7 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
vad_threshold=None,
|
||||
min_speech_duration_ms=None,
|
||||
min_silence_duration_ms=None,
|
||||
keyterms=None,
|
||||
)
|
||||
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
@@ -771,6 +782,11 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
params.append(f"audio_format={self._audio_format}")
|
||||
params.append(f"commit_strategy={self._commit_strategy.value}")
|
||||
|
||||
keyterms = self._settings.keyterms
|
||||
if is_given(keyterms) and keyterms is not None:
|
||||
for keyterm in keyterms:
|
||||
params.append(urlencode({"keyterms": keyterm}))
|
||||
|
||||
# Add optional parameters
|
||||
if self._include_timestamps:
|
||||
params.append(f"include_timestamps={str(self._include_timestamps).lower()}")
|
||||
|
||||
108
tests/test_elevenlabs_stt.py
Normal file
108
tests/test_elevenlabs_stt.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from pipecat.services.elevenlabs.stt import (
|
||||
CommitStrategy,
|
||||
ElevenLabsRealtimeSTTService,
|
||||
ElevenLabsSTTService,
|
||||
audio_format_from_sample_rate,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_elevenlabs_stt_sends_keyterms_multipart_fields(aiohttp_client):
|
||||
captured = {"headers": {}, "fields": []}
|
||||
|
||||
async def handler(request):
|
||||
captured["headers"]["xi-api-key"] = request.headers.get("xi-api-key")
|
||||
reader = await request.multipart()
|
||||
|
||||
async for part in reader:
|
||||
if part.name == "file":
|
||||
await part.read()
|
||||
else:
|
||||
captured["fields"].append((part.name, await part.text()))
|
||||
|
||||
return web.json_response({"text": "hello", "language_code": "eng", "words": []})
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/v1/speech-to-text", handler)
|
||||
client = await aiohttp_client(app)
|
||||
base_url = str(client.make_url("/")).rstrip("/")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
service = ElevenLabsSTTService(
|
||||
api_key="test-key",
|
||||
aiohttp_session=session,
|
||||
base_url=base_url,
|
||||
settings=ElevenLabsSTTService.Settings(
|
||||
language=Language.EN,
|
||||
keyterms=["Pipecat", "Scribe V2"],
|
||||
),
|
||||
)
|
||||
|
||||
result = await service._transcribe_audio(b"RIFF")
|
||||
|
||||
assert result["text"] == "hello"
|
||||
assert captured["headers"]["xi-api-key"] == "test-key"
|
||||
assert ("model_id", "scribe_v2") in captured["fields"]
|
||||
assert ("language_code", "eng") in captured["fields"]
|
||||
assert [value for name, value in captured["fields"] if name == "keyterms"] == [
|
||||
"Pipecat",
|
||||
"Scribe V2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_elevenlabs_realtime_websocket_url_includes_keyterms(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
async def fake_websocket_connect(url, *, additional_headers):
|
||||
captured["url"] = url
|
||||
captured["headers"] = additional_headers
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.elevenlabs.stt.websocket_connect",
|
||||
fake_websocket_connect,
|
||||
)
|
||||
|
||||
service = ElevenLabsRealtimeSTTService(
|
||||
api_key="test-key",
|
||||
base_url="example.test",
|
||||
commit_strategy=CommitStrategy.VAD,
|
||||
sample_rate=16000,
|
||||
include_timestamps=True,
|
||||
settings=ElevenLabsRealtimeSTTService.Settings(
|
||||
language=Language.EN,
|
||||
keyterms=["Pipecat", "Scribe V2"],
|
||||
vad_threshold=0.7,
|
||||
),
|
||||
)
|
||||
service._audio_format = audio_format_from_sample_rate(16000)
|
||||
|
||||
await service._connect_websocket()
|
||||
|
||||
parsed = urlparse(captured["url"])
|
||||
query = parse_qs(parsed.query)
|
||||
assert parsed.scheme == "wss"
|
||||
assert parsed.netloc == "example.test"
|
||||
assert parsed.path == "/v1/speech-to-text/realtime"
|
||||
assert query["model_id"] == ["scribe_v2_realtime"]
|
||||
assert query["language_code"] == ["en"]
|
||||
assert query["audio_format"] == ["pcm_16000"]
|
||||
assert query["commit_strategy"] == ["vad"]
|
||||
assert query["include_timestamps"] == ["true"]
|
||||
assert query["vad_threshold"] == ["0.7"]
|
||||
assert query["keyterms"] == ["Pipecat", "Scribe V2"]
|
||||
assert captured["headers"] == {"xi-api-key": "test-key"}
|
||||
Reference in New Issue
Block a user