diff --git a/engine/config.py b/engine/config.py index d8cc655..913d103 100644 --- a/engine/config.py +++ b/engine/config.py @@ -91,6 +91,7 @@ class STTConfig: encoding: str = "raw" frame_size: int = 1280 timeout_sec: float = 10.0 + dynamic_correction: bool = False @dataclass(frozen=True) diff --git a/engine/services.py b/engine/services.py index 046de91..b322be6 100644 --- a/engine/services.py +++ b/engine/services.py @@ -32,6 +32,7 @@ def create_stt_service(config: STTConfig, audio: AudioConfig | None = None): encoding=config.encoding, frame_size=config.frame_size, open_timeout=config.timeout_sec, + dynamic_correction=config.dynamic_correction, ) _require_provider(config.provider, "openai", "stt") diff --git a/engine/xfyun_asr.py b/engine/xfyun_asr.py index 81c8b5a..939e5c7 100644 --- a/engine/xfyun_asr.py +++ b/engine/xfyun_asr.py @@ -17,7 +17,6 @@ from loguru import logger from pipecat.frames.frames import ( CancelFrame, EndFrame, - ErrorFrame, Frame, InterimTranscriptionFrame, TranscriptionFrame, @@ -53,6 +52,7 @@ class XfyunASRService(STTService): encoding: str = "raw", frame_size: int = 1280, open_timeout: float = 10.0, + dynamic_correction: bool = False, **kwargs, ) -> None: super().__init__( @@ -70,6 +70,7 @@ class XfyunASRService(STTService): self._encoding = encoding self._frame_size = frame_size self._open_timeout = open_timeout + self._dynamic_correction = dynamic_correction self._websocket = None self._receive_task = None @@ -79,11 +80,6 @@ class XfyunASRService(STTService): self._finalizing_turn = False self._partials: list[str] = [] self._last_text = "" - # Text already emitted as TranscriptionFrame deltas for Pipecat's - # user-turn strategy and context aggregator. The xfyun websocket now - # spans a full logical user turn, so UI interim frames can use - # `_last_text` directly while this cursor prevents duplicate context. - self._turn_transcription_text = "" async def cleanup(self) -> None: await self._close_utterance() @@ -135,7 +131,6 @@ class XfyunASRService(STTService): self._audio_buffer.clear() self._partials = [] self._last_text = "" - self._turn_transcription_text = "" self._sent_first_frame = False self._sent_final_frame = False @@ -206,13 +201,17 @@ class XfyunASRService(STTService): return if not self._sent_first_frame: + business = { + "language": self._language, + "domain": self._domain, + "accent": self._accent, + } + if self._dynamic_correction: + business["dwa"] = "wpgs" + payload = { "common": {"app_id": self._app_id}, - "business": { - "language": self._language, - "domain": self._domain, - "accent": self._accent, - }, + "business": business, "data": { "status": 0, "format": f"audio/L16;rate={self.sample_rate}", @@ -266,12 +265,13 @@ class XfyunASRService(STTService): if not isinstance(data, dict): return + is_final_response = data.get("status") == 2 recognition = data.get("result") if isinstance(recognition, dict): text = self._apply_recognition_result(recognition) if text and text != self._last_text: self._last_text = text - if not self._finalizing_turn: + if not self._finalizing_turn and not is_final_response: await self.push_frame( InterimTranscriptionFrame( text, @@ -281,51 +281,28 @@ class XfyunASRService(STTService): result=payload, ) ) - await self._push_transcription_delta(text, result=payload, finalized=False) - if data.get("status") == 2: + if is_final_response: final_text = self._last_text - if final_text and not self._finalizing_turn: + if final_text: self.confirm_finalize() - await self._push_transcription_delta(final_text, result=payload, finalized=True) + await self.push_frame( + TranscriptionFrame( + final_text, + self._user_id, + time_now_iso8601(), + _language_or_none(self._language), + result=payload, + ) + ) await self._close_utterance() - async def _push_transcription_delta( - self, - text: str, - *, - result: dict[str, Any], - finalized: bool, - ) -> None: - if text.startswith(self._turn_transcription_text): - delta = text[len(self._turn_transcription_text) :] - else: - logger.debug( - "Xfyun transcript replacement is not append-only; " - "continuing with the new suffix for turn aggregation" - ) - delta = text - - if not delta and not finalized: - return - - self._turn_transcription_text = text - await self.push_frame( - TranscriptionFrame( - delta, - self._user_id, - time_now_iso8601(), - _language_or_none(self._language), - result=result, - ) - ) - def _apply_recognition_result(self, recognition: dict[str, Any]) -> str: partial = _extract_text_from_result(recognition) if not partial: return self._last_text - if recognition.get("pgs") == "rpl" and recognition.get("rg"): + if self._dynamic_correction and recognition.get("pgs") == "rpl" and recognition.get("rg"): start, end = recognition["rg"] if 1 <= start <= len(self._partials): self._partials[start - 1 : end] = [partial]