Files
pipecat/tests/test_elevenlabs_stt.py
2026-05-07 21:00:26 +00:00

109 lines
3.4 KiB
Python

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from urllib.parse import parse_qs, urlparse
import aiohttp
import pytest
from aiohttp import web
from pipecat.services.elevenlabs.stt import (
CommitStrategy,
ElevenLabsRealtimeSTTService,
ElevenLabsSTTService,
audio_format_from_sample_rate,
)
from pipecat.transcriptions.language import Language
@pytest.mark.asyncio
async def test_elevenlabs_stt_sends_keyterms_multipart_fields(aiohttp_client):
captured = {"headers": {}, "fields": []}
async def handler(request):
captured["headers"]["xi-api-key"] = request.headers.get("xi-api-key")
reader = await request.multipart()
async for part in reader:
if part.name == "file":
await part.read()
else:
captured["fields"].append((part.name, await part.text()))
return web.json_response({"text": "hello", "language_code": "eng", "words": []})
app = web.Application()
app.router.add_post("/v1/speech-to-text", handler)
client = await aiohttp_client(app)
base_url = str(client.make_url("/")).rstrip("/")
async with aiohttp.ClientSession() as session:
service = ElevenLabsSTTService(
api_key="test-key",
aiohttp_session=session,
base_url=base_url,
settings=ElevenLabsSTTService.Settings(
language=Language.EN,
keyterms=["Pipecat", "Scribe V2"],
),
)
result = await service._transcribe_audio(b"RIFF")
assert result["text"] == "hello"
assert captured["headers"]["xi-api-key"] == "test-key"
assert ("model_id", "scribe_v2") in captured["fields"]
assert ("language_code", "eng") in captured["fields"]
assert [value for name, value in captured["fields"] if name == "keyterms"] == [
"Pipecat",
"Scribe V2",
]
@pytest.mark.asyncio
async def test_elevenlabs_realtime_websocket_url_includes_keyterms(monkeypatch):
captured = {}
async def fake_websocket_connect(url, *, additional_headers):
captured["url"] = url
captured["headers"] = additional_headers
return object()
monkeypatch.setattr(
"pipecat.services.elevenlabs.stt.websocket_connect",
fake_websocket_connect,
)
service = ElevenLabsRealtimeSTTService(
api_key="test-key",
base_url="example.test",
commit_strategy=CommitStrategy.VAD,
sample_rate=16000,
include_timestamps=True,
settings=ElevenLabsRealtimeSTTService.Settings(
language=Language.EN,
keyterms=["Pipecat", "Scribe V2"],
vad_threshold=0.7,
),
)
service._audio_format = audio_format_from_sample_rate(16000)
await service._connect_websocket()
parsed = urlparse(captured["url"])
query = parse_qs(parsed.query)
assert parsed.scheme == "wss"
assert parsed.netloc == "example.test"
assert parsed.path == "/v1/speech-to-text/realtime"
assert query["model_id"] == ["scribe_v2_realtime"]
assert query["language_code"] == ["en"]
assert query["audio_format"] == ["pcm_16000"]
assert query["commit_strategy"] == ["vad"]
assert query["include_timestamps"] == ["true"]
assert query["vad_threshold"] == ["0.7"]
assert query["keyterms"] == ["Pipecat", "Scribe V2"]
assert captured["headers"] == {"xi-api-key": "test-key"}