Propagate Soniox token language
This commit is contained in:
@@ -58,6 +58,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# Add strict mode to enforce the language hints
|
||||
language_hints=[Language.EN],
|
||||
language_hints_strict=True,
|
||||
enable_language_identification=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -201,6 +201,19 @@ def _prepare_language_hints(
|
||||
return list(set(prepared_languages))
|
||||
|
||||
|
||||
def _language_from_tokens(tokens: list[dict]) -> Language | None:
    """Return the language of the latest token that carries a usable one.

    Tokens are examined from last to first. A token whose ``"language"``
    entry is missing or falsy is skipped, as is one whose value ``Language``
    rejects with ``ValueError``.

    Args:
        tokens: Token dictionaries; each may carry a ``"language"`` entry.

    Returns:
        The ``Language`` built from the latest usable entry, or ``None`` when
        no token yields one.
    """
    # Lazily walk the tokens newest-first, dropping missing/empty entries.
    candidates = filter(None, (entry.get("language") for entry in reversed(tokens)))
    for code in candidates:
        try:
            return Language(code)
        except ValueError:
            # Value is not accepted by Language; try the next older token.
            continue
    return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SonioxSTTSettings(STTSettings):
|
||||
"""Settings for SonioxSTTService.
|
||||
@@ -557,6 +570,7 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
async def send_endpoint_transcript():
|
||||
if self._final_transcription_buffer:
|
||||
text = "".join(map(lambda token: token["text"], self._final_transcription_buffer))
|
||||
language = _language_from_tokens(self._final_transcription_buffer)
|
||||
# Soniox only pushes TranscriptionFrame when an end token is received,
|
||||
# so every TranscriptionFrame is inherently finalized
|
||||
await self.push_frame(
|
||||
@@ -564,11 +578,12 @@ class SonioxSTTService(WebsocketSTTService):
|
||||
text=text,
|
||||
user_id=self._user_id,
|
||||
timestamp=time_now_iso8601(),
|
||||
language=language,
|
||||
result=self._final_transcription_buffer,
|
||||
finalized=True,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(text, is_final=True)
|
||||
await self._handle_transcription(text, is_final=True, language=language)
|
||||
await self.stop_processing_metrics()
|
||||
self._final_transcription_buffer = []
|
||||
|
||||
|
||||
163
tests/test_soniox_stt.py
Normal file
163
tests/test_soniox_stt.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from pipecat.frames.frames import TranscriptionFrame
|
||||
from pipecat.services.soniox.stt import END_TOKEN, SonioxSTTService, _language_from_tokens
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
class _FakeWebsocket:
    """Websocket stand-in that replays a fixed sequence of messages.

    Only asynchronous iteration is provided; the object keeps the messages it
    was given and yields them in order each time it is iterated.
    """

    def __init__(self, messages):
        # Stored as given; replayed in order by _iter_messages().
        self._messages = messages

    def __aiter__(self):
        # A fresh async generator is created for every iteration request.
        return self._iter_messages()

    async def _iter_messages(self):
        """Yield each stored message, first to last."""
        for payload in self._messages:
            yield payload
|
||||
|
||||
|
||||
def test_language_from_tokens_uses_single_recognized_language():
    """Every token reports "en", so the helper is expected to return Language.EN."""
    detected = _language_from_tokens(
        [
            {"text": "Hello", "language": "en"},
            {"text": " world", "language": "en"},
        ]
    )

    assert detected == Language.EN
|
||||
|
||||
|
||||
def test_language_from_tokens_uses_latest_language():
    """With "nl" followed by "en", the language of the last token is expected."""
    dutch_then_english = [
        {"text": "Hallo", "language": "nl"},
        {"text": " world", "language": "en"},
    ]

    detected = _language_from_tokens(dutch_then_english)

    assert detected == Language.EN
|
||||
|
||||
|
||||
def test_language_from_tokens_skips_unknown_latest_language():
    """An unrecognized language on the last token falls back to the earlier "en"."""
    known_then_unknown = [
        {"text": " world", "language": "en"},
        {"text": "!", "language": "klingon"},
    ]

    detected = _language_from_tokens(known_then_unknown)

    assert detected == Language.EN
|
||||
|
||||
|
||||
def test_language_from_tokens_skips_missing_latest_language():
    """A last token without a "language" entry falls back to the earlier "en"."""
    tagged_then_untagged = [
        {"text": "Hello", "language": "en"},
        {"text": " wereld"},
    ]

    detected = _language_from_tokens(tagged_then_untagged)

    assert detected == Language.EN
|
||||
|
||||
|
||||
def test_language_from_tokens_ignores_unknown_and_missing_languages():
    """No token carries a recognized language, so None is expected."""
    detected = _language_from_tokens(
        [
            {"text": "Hello", "language": "klingon"},
            {"text": " world"},
            {"text": "!"},
        ]
    )

    assert detected is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_receive_messages_sets_final_transcription_language(monkeypatch):
    """Tokens tagged "en" followed by END_TOKEN yield one frame with Language.EN.

    The service is fed two websocket messages through a fake socket, with
    push_frame, _handle_transcription and stop_processing_metrics replaced by
    recorders. The asserted ``result`` contains only the two text tokens.
    """
    hello = {"text": "Hello", "is_final": True, "language": "en"}
    world = {"text": " world", "is_final": True, "language": "en"}
    end_marker = {"text": END_TOKEN, "is_final": True}

    frames_seen = []
    transcription_calls = []

    async def record_frame(frame):
        frames_seen.append(frame)

    async def record_transcription(transcript, is_final, language=None):
        transcription_calls.append((transcript, is_final, language))

    async def skip_metrics():
        pass

    service = SonioxSTTService(api_key="test-key")
    service._websocket = _FakeWebsocket(
        [
            json.dumps({"tokens": [hello, world, end_marker]}),
            json.dumps({"tokens": [], "finished": True}),
        ]
    )
    monkeypatch.setattr(service, "push_frame", record_frame)
    monkeypatch.setattr(service, "_handle_transcription", record_transcription)
    monkeypatch.setattr(service, "stop_processing_metrics", skip_metrics)

    await service._receive_messages()

    # Other frame types may be pushed too; only transcription frames matter here.
    transcripts = [item for item in frames_seen if isinstance(item, TranscriptionFrame)]
    assert len(transcripts) == 1

    final = transcripts[0]
    assert final.text == "Hello world"
    assert final.language == Language.EN
    assert final.finalized is True
    assert final.result == [hello, world]
    assert transcription_calls == [("Hello world", True, Language.EN)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_receive_messages_allows_final_transcription_without_language(monkeypatch):
    """Tokens without a "language" entry still yield one frame, with language None.

    The service is fed two websocket messages through a fake socket, with
    push_frame, _handle_transcription and stop_processing_metrics replaced by
    recorders.
    """
    words = ["Tell", " me", " a", " joke."]
    token_payload = [{"text": word, "is_final": True} for word in words]
    token_payload.append({"text": END_TOKEN, "is_final": True})

    frames_seen = []
    transcription_calls = []

    async def record_frame(frame):
        frames_seen.append(frame)

    async def record_transcription(transcript, is_final, language=None):
        transcription_calls.append((transcript, is_final, language))

    async def skip_metrics():
        pass

    service = SonioxSTTService(api_key="test-key")
    service._websocket = _FakeWebsocket(
        [
            json.dumps({"tokens": token_payload}),
            json.dumps({"tokens": [], "finished": True}),
        ]
    )
    monkeypatch.setattr(service, "push_frame", record_frame)
    monkeypatch.setattr(service, "_handle_transcription", record_transcription)
    monkeypatch.setattr(service, "stop_processing_metrics", skip_metrics)

    await service._receive_messages()

    # Other frame types may be pushed too; only transcription frames matter here.
    transcripts = [item for item in frames_seen if isinstance(item, TranscriptionFrame)]
    assert len(transcripts) == 1

    final = transcripts[0]
    assert final.text == "Tell me a joke."
    assert final.language is None
    assert final.finalized is True
    assert transcription_calls == [("Tell me a joke.", True, None)]
|
||||
Reference in New Issue
Block a user